diff --git a/WORKSPACE b/WORKSPACE index 31a7a1b29..0aee35c67 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -64,6 +64,12 @@ http_archive( sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8", strip_prefix = "glog-0.3.5", build_file = "@//third_party:glog.BUILD", + patches = [ + "@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff" + ], + patch_args = [ + "-p1", + ], ) # libyuv diff --git a/mediapipe/calculators/audio/audio_decoder_calculator.cc b/mediapipe/calculators/audio/audio_decoder_calculator.cc index 24689492e..b80b64bae 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator.cc @@ -61,7 +61,9 @@ class AudioDecoderCalculator : public CalculatorBase { ::mediapipe::Status AudioDecoderCalculator::GetContract( CalculatorContract* cc) { cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); - + if (cc->InputSidePackets().HasTag("OPTIONS")) { + cc->InputSidePackets().Tag("OPTIONS").Set(); + } cc->Outputs().Tag("AUDIO").Set(); if (cc->Outputs().HasTag("AUDIO_HEADER")) { cc->Outputs().Tag("AUDIO_HEADER").SetNone(); @@ -72,7 +74,9 @@ class AudioDecoderCalculator : public CalculatorBase { ::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { const std::string& input_file_path = cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); - const auto& decoder_options = cc->Options(); + const auto& decoder_options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); decoder_ = absl::make_unique(); MP_RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options)); std::unique_ptr header = diff --git a/mediapipe/calculators/audio/stabilized_log_calculator.cc b/mediapipe/calculators/audio/stabilized_log_calculator.cc index 50ccc01a0..b5623ee0f 100644 --- a/mediapipe/calculators/audio/stabilized_log_calculator.cc +++ b/mediapipe/calculators/audio/stabilized_log_calculator.cc @@ -75,8 +75,13 @@ class StabilizedLogCalculator : public CalculatorBase { ::mediapipe::Status Process(CalculatorContext* cc) override { auto input_matrix = cc->Inputs().Index(0).Get(); + if (input_matrix.array().isNaN().any()) { + return ::mediapipe::InvalidArgumentError("NaN input to log operation."); + } if (check_nonnegativity_) { - CHECK_GE(input_matrix.minCoeff(), 0); + if (input_matrix.minCoeff() < 0.0) { + return ::mediapipe::OutOfRangeError("Negative input to log operation."); + } } std::unique_ptr output_frame(new Matrix( output_scale_ * (input_matrix.array() + stabilizer_).log().matrix())); diff --git a/mediapipe/calculators/audio/stabilized_log_calculator_test.cc b/mediapipe/calculators/audio/stabilized_log_calculator_test.cc index 9831f4fe9..e6e0b5c6f 100644 --- a/mediapipe/calculators/audio/stabilized_log_calculator_test.cc +++ b/mediapipe/calculators/audio/stabilized_log_calculator_test.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
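For context on the StabilizedLogCalculator change above: the hard CHECK is replaced by recoverable statuses, so invalid input now fails the graph run instead of crashing the process. Below is a minimal sketch of that validation pattern as a standalone helper; the helper name is illustrative only and not part of this change.

#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Illustrative only: the calculator performs these checks inline in Process().
::mediapipe::Status ValidateLogInput(const Matrix& input,
                                     bool check_nonnegativity) {
  // NaN would otherwise propagate silently through the log; reject it.
  if (input.array().isNaN().any()) {
    return ::mediapipe::InvalidArgumentError("NaN input to log operation.");
  }
  // Negative values are rejected only when nonnegativity is required.
  if (check_nonnegativity && input.minCoeff() < 0.0) {
    return ::mediapipe::OutOfRangeError("Negative input to log operation.");
  }
  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe

Returning a Status is what lets the updated tests below assert ASSERT_FALSE(RunGraph().ok()) instead of ASSERT_DEATH.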
+#include <cmath> #include "Eigen/Core" #include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h" @@ -108,13 +109,22 @@ TEST_F(StabilizedLogCalculatorTest, ZerosAreStabilized) { runner_->Outputs().Index(0).packets[0].Get()); } -TEST_F(StabilizedLogCalculatorTest, NegativeValuesCheckFail) { +TEST_F(StabilizedLogCalculatorTest, NanValuesReturnError) { + InitializeGraph(); + FillInputHeader(); + AppendInputPacket( + new Matrix(Matrix::Constant(kNumChannels, kNumSamples, std::nanf(""))), + 0 /* timestamp */); + ASSERT_FALSE(RunGraph().ok()); +} + +TEST_F(StabilizedLogCalculatorTest, NegativeValuesReturnError) { InitializeGraph(); FillInputHeader(); AppendInputPacket( + new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0)), 0 /* timestamp */); - ASSERT_DEATH(RunGraphNoReturn(), ""); + ASSERT_FALSE(RunGraph().ok()); } TEST_F(StabilizedLogCalculatorTest, NegativeValuesDoNotCheckFailIfCheckIsOff) { diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc index 34adb5700..04f593bca 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator.cc @@ -56,6 +56,14 @@ namespace mediapipe { // If pad_final_packet is true, all input samples will be emitted and the final // packet will be zero padded as necessary. If pad_final_packet is false, some // samples may be dropped at the end of the stream. +// +// If use_local_timestamp is true, the output packet's timestamp is based on the +// last sample of the packet. The timestamp of this sample is inferred as +// input_packet_timestamp + local_sample_index / sample_rate_. If false, the +// output packet's timestamp is based on cumulative timestamping, which is +// done by adopting the timestamp of the first sample of the packet; this +// sample's timestamp is inferred as initial_input_timestamp_ + +// cumulative_completed_samples_ / sample_rate_. class TimeSeriesFramerCalculator : public CalculatorBase { public: static ::mediapipe::Status GetContract(CalculatorContract* cc) { @@ -86,11 +94,26 @@ class TimeSeriesFramerCalculator : public CalculatorBase { void FrameOutput(CalculatorContext* cc); Timestamp CurrentOutputTimestamp() { + if (use_local_timestamp_) { + return current_timestamp_; + } + return CumulativeOutputTimestamp(); + } + + Timestamp CumulativeOutputTimestamp() { return initial_input_timestamp_ + round(cumulative_completed_samples_ / sample_rate_ * Timestamp::kTimestampUnitsPerSecond); } + // Returns the timestamp of a sample relative to a base timestamp, which is + // usually the timestamp of a packet. + Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base, + int64 number_of_samples) { + return timestamp_base + round(number_of_samples / sample_rate_ * + Timestamp::kTimestampUnitsPerSecond); + } + // The number of input samples to advance after the current output frame is // emitted. int next_frame_step_samples() const { @@ -118,14 +141,18 @@ class TimeSeriesFramerCalculator : public CalculatorBase { // any overlap). int64 cumulative_completed_samples_; Timestamp initial_input_timestamp_; + // The current timestamp is updated along with the incoming packets. + Timestamp current_timestamp_; int num_channels_; // Each entry in this deque consists of a single sample, i.e. a - // single column vector. - std::deque sample_buffer_; + // single column vector, and its timestamp. 
+ std::deque> sample_buffer_; bool use_window_; Matrix window_; + + bool use_local_timestamp_; }; REGISTER_CALCULATOR(TimeSeriesFramerCalculator); @@ -133,7 +160,8 @@ void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) { const Matrix& input_frame = cc->Inputs().Index(0).Get(); for (int i = 0; i < input_frame.cols(); ++i) { - sample_buffer_.emplace_back(input_frame.col(i)); + sample_buffer_.emplace_back(std::make_pair( + input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i))); } cumulative_input_samples_ += input_frame.cols(); @@ -151,14 +179,16 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { new Matrix(num_channels_, frame_duration_samples_)); for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_); ++i) { - output_frame->col(i) = sample_buffer_.front(); + output_frame->col(i) = sample_buffer_.front().first; + current_timestamp_ = sample_buffer_.front().second; sample_buffer_.pop_front(); } const int frame_overlap_samples = frame_duration_samples_ - frame_step_samples; if (frame_overlap_samples > 0) { for (int i = 0; i < frame_overlap_samples; ++i) { - output_frame->col(i + frame_step_samples) = sample_buffer_[i]; + output_frame->col(i + frame_step_samples) = sample_buffer_[i].first; + current_timestamp_ = sample_buffer_[i].second; } } else { samples_still_to_drop_ = -frame_overlap_samples; @@ -178,6 +208,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { ::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { if (initial_input_timestamp_ == Timestamp::Unstarted()) { initial_input_timestamp_ = cc->InputTimestamp(); + current_timestamp_ = initial_input_timestamp_; } EnqueueInput(cc); @@ -195,7 +226,8 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { std::unique_ptr output_frame(new Matrix); output_frame->setZero(num_channels_, frame_duration_samples_); for (int i = 0; i < sample_buffer_.size(); ++i) { - output_frame->col(i) = sample_buffer_[i]; + output_frame->col(i) = sample_buffer_[i].first; + current_timestamp_ = sample_buffer_[i].second; } cc->Outputs().Index(0).Add(output_frame.release(), @@ -258,6 +290,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { cumulative_output_frames_ = 0; samples_still_to_drop_ = 0; initial_input_timestamp_ = Timestamp::Unstarted(); + current_timestamp_ = Timestamp::Unstarted(); std::vector window_vector; use_window_ = false; @@ -282,6 +315,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { frame_duration_samples_) .cast(); } + use_local_timestamp_ = framer_options.use_local_timestamp(); return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.proto b/mediapipe/calculators/audio/time_series_framer_calculator.proto index 61be38da7..9e5b07462 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator.proto +++ b/mediapipe/calculators/audio/time_series_framer_calculator.proto @@ -62,4 +62,11 @@ message TimeSeriesFramerCalculatorOptions { HANN = 2; } optional WindowFunction window_function = 4 [default = NONE]; + + // If use_local_timestamp is true, the output packet's timestamp is based on + // the last sample of the packet and it's inferred from the latest input + // packet's timestamp. If false, the output packet's timestamp is based on + // the cumulative timestamping, which is inferred from the intial input + // timestamp and the cumulative number of samples. 
+ optional bool use_local_timestamp = 6 [default = false]; } diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc index 1a370faa1..cd0c38e13 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc @@ -35,6 +35,8 @@ namespace mediapipe { namespace { const int kInitialTimestampOffsetMicroseconds = 4; +const int kGapBetweenPacketsInSeconds = 1; +const int kUniversalInputPacketSize = 50; class TimeSeriesFramerCalculatorTest : public TimeSeriesCalculatorTest { @@ -391,5 +393,93 @@ TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) { RunAndTestSinglePacketAverage(0.5f); } -} // anonymous namespace +// A test class that checks the output packet timestamps. This class +// generates a series of packets with and without gaps between packets and +// tests the behavior with both cumulative and local packet timestamping. +class TimeSeriesFramerCalculatorTimestampingTest + : public TimeSeriesFramerCalculatorTest { + protected: + // Creates test input and saves a reference copy. + void InitializeInputForTimeStampingTest() { + concatenated_input_samples_.resize(0, num_input_channels_); + num_input_samples_ = 0; + for (int i = 0; i < 10; ++i) { + // The input packet size is fixed (rather than spanning a range) so + // that the expected local timestamps computed below can be derived + // from the sample index alone. + int packet_size = kUniversalInputPacketSize; + double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 + + num_input_samples_ / input_sample_rate_; + if (options_.use_local_timestamp()) { + timestamp_seconds += kGapBetweenPacketsInSeconds * i; + } + + Matrix* data_frame = + NewTestFrame(num_input_channels_, packet_size, timestamp_seconds); + + AppendInputPacket(data_frame, round(timestamp_seconds * + Timestamp::kTimestampUnitsPerSecond)); + num_input_samples_ += packet_size; + } + } + + void CheckOutputTimestamps() { + int num_full_packets = output().packets.size(); + if (options_.pad_final_packet()) { + num_full_packets -= 1; + } + + int64 num_samples = 0; + for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) { + const Packet& packet = output().packets[packet_num]; + num_samples += FrameDurationSamples(); + double expected_timestamp = + options_.use_local_timestamp() + ? GetExpectedLocalTimestampForSample(num_samples - 1) + : GetExpectedCumulativeTimestamp(num_samples - 1); + ASSERT_NEAR(packet.Timestamp().Seconds(), expected_timestamp, 1e-10); + } + } + + ::mediapipe::Status RunTimestampTest() { + InitializeGraph(); + InitializeInputForTimeStampingTest(); + FillInputHeader(); + return RunGraph(); + } + + private: + // Returns the timestamp in seconds based on local timestamping. + double GetExpectedLocalTimestampForSample(int sample_index) { + return kInitialTimestampOffsetMicroseconds * 1.0e-6 + + sample_index / input_sample_rate_ + + (sample_index / kUniversalInputPacketSize) * + kGapBetweenPacketsInSeconds; + } + + // Returns the timestamp in seconds based on cumulative timestamping. 
+ double GetExpectedCumulativeTimestamp(int sample_index) { + return kInitialTimestampOffsetMicroseconds * 1.0e-6 + + sample_index / FrameDurationSamples() * FrameDurationSamples() / + input_sample_rate_; + } +}; + +TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseLocalTimeStamp) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_use_local_timestamp(true); + + MP_ASSERT_OK(RunTimestampTest()); + CheckOutputTimestamps(); +} + +TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseCumulativeTimeStamp) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_use_local_timestamp(false); + + MP_ASSERT_OK(RunTimestampTest()); + CheckOutputTimestamps(); +} + +} // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ebef8127f..80205f90e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -166,7 +166,13 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@org_tensorflow//tensorflow/lite:framework", - ], + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//mediapipe:ios": [], + "//conditions:default": [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + ], + }), alwayslink = 1, ) diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc index 7a8445f48..c4144990e 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -19,6 +19,10 @@ #include "mediapipe/framework/formats/landmark.pb.h" #include "tensorflow/lite/interpreter.h" +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#endif // !MEDIAPIPE_DISABLE_GPU + namespace mediapipe { // Example config: @@ -45,4 +49,11 @@ REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator); typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark> ConcatenateLandmarkVectorCalculator; REGISTER_CALCULATOR(ConcatenateLandmarkVectorCalculator); + +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +typedef ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer> + ConcatenateGlBufferVectorCalculator; +REGISTER_CALCULATOR(ConcatenateGlBufferVectorCalculator); +#endif + } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.h b/mediapipe/calculators/core/concatenate_vector_calculator.h index b7ee24a9c..08e8e954f 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.h +++ b/mediapipe/calculators/core/concatenate_vector_calculator.h @@ -15,6 +15,7 @@ #ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ +#include #include #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" @@ -59,16 +60,58 @@ class ConcatenateVectorCalculator : public CalculatorBase { if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus(); } } - auto output = absl::make_unique>(); + + return ConcatenateVectors(std::is_copy_constructible(), cc); + } + + template + ::mediapipe::Status ConcatenateVectors(std::true_type, + CalculatorContext* cc) { + auto output = absl::make_unique>(); for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { if (cc->Inputs().Index(i).IsEmpty()) continue; - const std::vector& input = cc->Inputs().Index(i).Get>(); + const std::vector& 
input = cc->Inputs().Index(i).Get>(); output->insert(output->end(), input.begin(), input.end()); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); return ::mediapipe::OkStatus(); } + template + ::mediapipe::Status ConcatenateVectors(std::false_type, + CalculatorContext* cc) { + return ConsumeAndConcatenateVectors(std::is_move_constructible(), cc); + } + + template + ::mediapipe::Status ConsumeAndConcatenateVectors(std::true_type, + CalculatorContext* cc) { + auto output = absl::make_unique>(); + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + if (cc->Inputs().Index(i).IsEmpty()) continue; + ::mediapipe::StatusOr>> input_status = + cc->Inputs().Index(i).Value().Consume>(); + if (input_status.ok()) { + std::unique_ptr> input_vector = + std::move(input_status).ValueOrDie(); + output->insert(output->end(), + std::make_move_iterator(input_vector->begin()), + std::make_move_iterator(input_vector->end())); + } else { + return input_status.status(); + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } + + template + ::mediapipe::Status ConsumeAndConcatenateVectors(std::false_type, + CalculatorContext* cc) { + return ::mediapipe::InternalError( + "Cannot copy or move input vectors to concatenate them"); + } + private: bool only_emit_if_all_present_; }; diff --git a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc index 0baceaa26..4b27c2030 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc @@ -235,4 +235,167 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) { EXPECT_EQ(0, outputs.size()); } +typedef ConcatenateVectorCalculator> + TestConcatenateUniqueIntPtrCalculator; +REGISTER_CALCULATOR(TestConcatenateUniqueIntPtrCalculator); + +TEST(TestConcatenateUniqueIntVectorCalculatorTest, ConsumeOneTimestamp) { + /* Note: We don't use CalculatorRunner for this test because it keeps copies + * of input packets, so packets sent to the graph don't have sole ownership. + * The test needs to send packets that own the data. 
+ */ + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"( + input_stream: "in_1" + input_stream: "in_2" + input_stream: "in_3" + node { + calculator: "TestConcatenateUniqueIntPtrCalculator" + input_stream: "in_1" + input_stream: "in_2" + input_stream: "in_3" + output_stream: "out" + } + )"); + + std::vector outputs; + tool::AddVectorSink("out", &graph_config, &outputs); + + CalculatorGraph graph; + MP_EXPECT_OK(graph.Initialize(graph_config)); + MP_EXPECT_OK(graph.StartRun({})); + + // input1 : {0, 1, 2} + std::unique_ptr>> input_1 = + absl::make_unique>>(3); + for (int i = 0; i < 3; ++i) { + input_1->at(i) = absl::make_unique(i); + } + // input2: {3} + std::unique_ptr>> input_2 = + absl::make_unique>>(1); + input_2->at(0) = absl::make_unique(3); + // input3: {4, 5} + std::unique_ptr>> input_3 = + absl::make_unique>>(2); + input_3->at(0) = absl::make_unique(4); + input_3->at(1) = absl::make_unique(5); + + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in_1", Adopt(input_1.release()).At(Timestamp(1)))); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in_2", Adopt(input_2.release()).At(Timestamp(1)))); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in_3", Adopt(input_3.release()).At(Timestamp(1)))); + + MP_EXPECT_OK(graph.WaitUntilIdle()); + MP_EXPECT_OK(graph.CloseAllPacketSources()); + MP_EXPECT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + const std::vector>& result = + outputs[0].Get>>(); + EXPECT_EQ(6, result.size()); + for (int i = 0; i < 6; ++i) { + const std::unique_ptr& v = result[i]; + EXPECT_EQ(i, *v); + } +} + +TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamStillOutput) { + /* Note: We don't use CalculatorRunner for this test because it keeps copies + * of input packets, so packets sent to the graph don't have sole ownership. + * The test needs to send packets that own the data. + */ + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"( + input_stream: "in_1" + input_stream: "in_2" + node { + calculator: "TestConcatenateUniqueIntPtrCalculator" + input_stream: "in_1" + input_stream: "in_2" + output_stream: "out" + } + )"); + + std::vector outputs; + tool::AddVectorSink("out", &graph_config, &outputs); + + CalculatorGraph graph; + MP_EXPECT_OK(graph.Initialize(graph_config)); + MP_EXPECT_OK(graph.StartRun({})); + + // input1 : {0, 1, 2} + std::unique_ptr>> input_1 = + absl::make_unique>>(3); + for (int i = 0; i < 3; ++i) { + input_1->at(i) = absl::make_unique(i); + } + + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in_1", Adopt(input_1.release()).At(Timestamp(1)))); + + MP_EXPECT_OK(graph.WaitUntilIdle()); + MP_EXPECT_OK(graph.CloseAllPacketSources()); + MP_EXPECT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + const std::vector>& result = + outputs[0].Get>>(); + EXPECT_EQ(3, result.size()); + for (int i = 0; i < 3; ++i) { + const std::unique_ptr& v = result[i]; + EXPECT_EQ(i, *v); + } +} + +TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamNoOutput) { + /* Note: We don't use CalculatorRunner for this test because it keeps copies + * of input packets, so packets sent to the graph don't have sole ownership. + * The test needs to send packets that own the data. 
+ */ + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"( + input_stream: "in_1" + input_stream: "in_2" + node { + calculator: "TestConcatenateUniqueIntPtrCalculator" + input_stream: "in_1" + input_stream: "in_2" + output_stream: "out" + options { + [mediapipe.ConcatenateVectorCalculatorOptions.ext] { + only_emit_if_all_present: true + } + } + } + )"); + + std::vector outputs; + tool::AddVectorSink("out", &graph_config, &outputs); + + CalculatorGraph graph; + MP_EXPECT_OK(graph.Initialize(graph_config)); + MP_EXPECT_OK(graph.StartRun({})); + + // input1 : {0, 1, 2} + std::unique_ptr>> input_1 = + absl::make_unique>>(3); + for (int i = 0; i < 3; ++i) { + input_1->at(i) = absl::make_unique(i); + } + + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in_1", Adopt(input_1.release()).At(Timestamp(1)))); + + MP_EXPECT_OK(graph.WaitUntilIdle()); + MP_EXPECT_OK(graph.CloseAllPacketSources()); + MP_EXPECT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(0, outputs.size()); +} + } // namespace mediapipe diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 3773a2180..b8fdbdfae 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -19,7 +19,6 @@ package(default_visibility = ["//visibility:private"]) exports_files(["LICENSE"]) load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") -load("@bazel_skylib//lib:selects.bzl", "selects") proto_library( name = "opencv_image_encoder_calculator_proto", @@ -227,19 +226,13 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:shader_util", ], - "//mediapipe:ios": [ - "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/gpu:shader_util", - ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -263,13 +256,13 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:shader_util", ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -322,14 +315,14 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - ] + selects.with_or({ - ("//mediapipe:android", "//mediapipe:ios"): [ + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:shader_util", ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -363,14 +356,15 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - ] + selects.with_or({ - ("//mediapipe:android", "//mediapipe:ios"): [ + "//mediapipe/gpu:gpu_buffer", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_quad_renderer", 
"//mediapipe/gpu:shader_util", ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -415,19 +409,13 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/util:color_cc_proto", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:shader_util", ], - "//mediapipe:ios": [ - "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/gpu:shader_util", - ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -486,11 +474,11 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - ] + selects.with_or({ - ("//mediapipe:android", "//mediapipe:ios"): [ + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gpu_buffer", ], - "//conditions:default": [], }), alwayslink = 1, ) diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index 0adb9390a..e1d26c1e0 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -27,11 +27,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -101,11 +101,11 @@ class BilateralFilterCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint program_joint_ = 0; -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(BilateralFilterCalculator); @@ -122,39 +122,46 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); return ::mediapipe::InternalError("GPU output must have GPU input."); } + bool use_gpu = false; + // Input image to filter. -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); + use_gpu |= true; } -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input guide image mask (optional) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag(kInputGuideTagGpu)) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) cc->Inputs().Tag(kInputGuideTagGpu).Set(); -#endif // __ANDROID__ || __EMSCRIPTEN__ + use_gpu |= true; } +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputGuideTag)) { cc->Inputs().Tag(kInputGuideTag).Set(); } // Output image. 
-#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); + use_gpu |= true; } -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ || __EMSCRIPTEN__ + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -166,11 +173,11 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; -#endif // __ANDROID__ || __EMSCRIPTEN__ + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif } sigma_color_ = options_.sigma_color(); @@ -180,9 +187,9 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); if (!use_gpu_) sigma_color_ *= 255.0; if (use_gpu_) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif +#endif // !MEDIAPIPE_DISABLE_GPU } return ::mediapipe::OkStatus(); @@ -190,7 +197,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); ::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { if (!gpu_initialized_) { @@ -200,7 +207,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); MP_RETURN_IF_ERROR(RenderGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } @@ -209,14 +216,14 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); } ::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; if (program_joint_) glDeleteProgram(program_joint_); program_joint_ = 0; }); -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -263,7 +270,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); @@ -321,13 +328,13 @@ REGISTER_CALCULATOR(BilateralFilterCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, 
-1.0f, // bottom right @@ -373,11 +380,11 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU } ::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { -#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#if !defined(MEDIAPIPE_DISABLE_GPU) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -545,7 +552,7 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { glUniform1i(glGetUniformLocation(program_joint_, "input_frame"), 1); glUniform1i(glGetUniformLocation(program_joint_, "guide_frame"), 2); -#endif // __ANDROID__ || __EMSCRIPTEN__ +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index b893cd260..a9277e871 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -24,12 +24,12 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -37,9 +37,20 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; namespace mediapipe { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +namespace { -#endif // __ANDROID__ or iOS +#if !defined(MEDIAPIPE_DISABLE_GPU) + +#endif // !MEDIAPIPE_DISABLE_GPU + +constexpr char kRectTag[] = "RECT"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kHeightTag[] = "HEIGHT"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageGpuTag[] = "IMAGE_GPU"; +constexpr char kWidthTag[] = "WIDTH"; + +} // namespace // Crops the input texture to the given rectangle region. The rectangle can // be at arbitrary location on the image with rotation. If there's rotation, the @@ -91,48 +102,55 @@ class ImageCroppingCalculator : public CalculatorBase { bool use_gpu_ = false; // Output texture corners (4) after transoformation in normalized coordinates. 
float transformed_points_[8]; -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) bool gpu_initialized_ = false; mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(ImageCroppingCalculator); ::mediapipe::Status ImageCroppingCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); - RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU")); + RET_CHECK(cc->Inputs().HasTag(kImageTag) ^ cc->Inputs().HasTag(kImageGpuTag)); + RET_CHECK(cc->Outputs().HasTag(kImageTag) ^ + cc->Outputs().HasTag(kImageGpuTag)); - if (cc->Inputs().HasTag("IMAGE")) { - RET_CHECK(cc->Outputs().HasTag("IMAGE")); - cc->Inputs().Tag("IMAGE").Set(); - cc->Outputs().Tag("IMAGE").Set(); - } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - if (cc->Inputs().HasTag("IMAGE_GPU")) { - RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU")); - cc->Inputs().Tag("IMAGE_GPU").Set(); - cc->Outputs().Tag("IMAGE_GPU").Set(); - } -#endif // __ANDROID__ or iOS + bool use_gpu = false; - if (cc->Inputs().HasTag("RECT")) { - cc->Inputs().Tag("RECT").Set(); + if (cc->Inputs().HasTag(kImageTag)) { + RET_CHECK(cc->Outputs().HasTag(kImageTag)); + cc->Inputs().Tag(kImageTag).Set(); + cc->Outputs().Tag(kImageTag).Set(); } - if (cc->Inputs().HasTag("NORM_RECT")) { - cc->Inputs().Tag("NORM_RECT").Set(); +#if !defined(MEDIAPIPE_DISABLE_GPU) + if (cc->Inputs().HasTag(kImageGpuTag)) { + RET_CHECK(cc->Outputs().HasTag(kImageGpuTag)); + cc->Inputs().Tag(kImageGpuTag).Set(); + cc->Outputs().Tag(kImageGpuTag).Set(); + use_gpu |= true; } - if (cc->Inputs().HasTag("WIDTH")) { - cc->Inputs().Tag("WIDTH").Set(); +#endif // !MEDIAPIPE_DISABLE_GPU + + RET_CHECK(cc->Inputs().HasTag(kRectTag) ^ cc->Inputs().HasTag(kNormRectTag)); + if (cc->Inputs().HasTag(kRectTag)) { + cc->Inputs().Tag(kRectTag).Set(); } - if (cc->Inputs().HasTag("HEIGHT")) { - cc->Inputs().Tag("HEIGHT").Set(); + if (cc->Inputs().HasTag(kNormRectTag)) { + cc->Inputs().Tag(kNormRectTag).Set(); + } + if (cc->Inputs().HasTag(kWidthTag)) { + cc->Inputs().Tag(kWidthTag).Set(); + } + if (cc->Inputs().HasTag(kHeightTag)) { + cc->Inputs().Tag(kHeightTag).Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ or iOS + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -140,26 +158,35 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); ::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - if (cc->Inputs().HasTag("IMAGE_GPU")) { + if (cc->Inputs().HasTag(kImageGpuTag)) { use_gpu_ = true; } options_ = cc->Options(); if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #else RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } return ::mediapipe::OkStatus(); } ::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().HasTag(kRectTag) && cc->Inputs().Tag(kRectTag).IsEmpty()) { + VLOG(1) << "RECT is empty for timestamp: " 
<< cc->InputTimestamp(); + return ::mediapipe::OkStatus(); + } + if (cc->Inputs().HasTag(kNormRectTag) && + cc->Inputs().Tag(kNormRectTag).IsEmpty()) { + VLOG(1) << "NORM_RECT is empty for timestamp: " << cc->InputTimestamp(); + return ::mediapipe::OkStatus(); + } if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { if (!gpu_initialized_) { @@ -169,7 +196,7 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); MP_RETURN_IF_ERROR(RenderGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } @@ -177,19 +204,22 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); } ::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); gpu_initialized_ = false; -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } ::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { - const auto& input_img = cc->Inputs().Tag("IMAGE").Get(); + if (cc->Inputs().Tag(kImageTag).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + const auto& input_img = cc->Inputs().Tag(kImageTag).Get(); cv::Mat input_mat = formats::MatView(&input_img); float rect_center_x = input_img.Width() / 2.0f; @@ -197,8 +227,8 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); float rotation = 0.0f; int target_width = input_img.Width(); int target_height = input_img.Height(); - if (cc->Inputs().HasTag("RECT")) { - const auto& rect = cc->Inputs().Tag("RECT").Get(); + if (cc->Inputs().HasTag(kRectTag)) { + const auto& rect = cc->Inputs().Tag(kRectTag).Get(); if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && rect.y_center() >= 0) { rect_center_x = rect.x_center(); @@ -207,8 +237,8 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); target_height = rect.height(); rotation = rect.rotation(); } - } else if (cc->Inputs().HasTag("NORM_RECT")) { - const auto& rect = cc->Inputs().Tag("NORM_RECT").Get(); + } else if (cc->Inputs().HasTag(kNormRectTag)) { + const auto& rect = cc->Inputs().Tag(kNormRectTag).Get(); if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 && rect.y_center() >= 0.0) { rect_center_x = std::round(rect.x_center() * input_img.Width()); @@ -218,9 +248,9 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); rotation = rect.rotation(); } } else { - if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) { - target_width = cc->Inputs().Tag("WIDTH").Get(); - target_height = cc->Inputs().Tag("HEIGHT").Get(); + if (cc->Inputs().HasTag(kWidthTag) && cc->Inputs().HasTag(kHeightTag)) { + target_width = cc->Inputs().Tag(kWidthTag).Get(); + target_height = cc->Inputs().Tag(kHeightTag).Get(); } else if (options_.has_width() && options_.has_height()) { target_width = options_.width(); target_height = options_.height(); @@ -253,16 +283,17 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); input_img.Format(), cropped_image.cols, cropped_image.rows)); cv::Mat output_mat = formats::MatView(output_frame.get()); cropped_image.copyTo(output_mat); - cc->Outputs().Tag("IMAGE").Add(output_frame.release(), cc->InputTimestamp()); + cc->Outputs().Tag(kImageTag).Add(output_frame.release(), + 
cc->InputTimestamp()); return ::mediapipe::OkStatus(); } ::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { - if (cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) { + if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); +#if !defined(MEDIAPIPE_DISABLE_GPU) + const Packet& input_packet = cc->Inputs().Tag(kImageGpuTag).Value(); const auto& input_buffer = input_packet.Get(); auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer); @@ -287,18 +318,18 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); // Send result image in GPU packet. auto output = dst_tex.GetFrame(); - cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); + cc->Outputs().Tag(kImageGpuTag).Add(output.release(), cc->InputTimestamp()); // Cleanup src_tex.Release(); dst_tex.Release(); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } void ImageCroppingCalculator::GlRender() { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -342,11 +373,11 @@ void ImageCroppingCalculator::GlRender() { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } ::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -392,7 +423,7 @@ void ImageCroppingCalculator::GlRender() { // Parameters glUseProgram(program_); glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -410,8 +441,8 @@ void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc, int y_center = src_height / 2; // Get the rotation of the cropping box. float rotation = 0.0f; - if (cc->Inputs().HasTag("RECT")) { - const auto& rect = cc->Inputs().Tag("RECT").Get(); + if (cc->Inputs().HasTag(kRectTag)) { + const auto& rect = cc->Inputs().Tag(kRectTag).Get(); // Only use the rect if it is valid. if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && rect.y_center() >= 0) { @@ -421,8 +452,8 @@ void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc, crop_height = rect.height(); rotation = rect.rotation(); } - } else if (cc->Inputs().HasTag("NORM_RECT")) { - const auto& rect = cc->Inputs().Tag("NORM_RECT").Get(); + } else if (cc->Inputs().HasTag(kNormRectTag)) { + const auto& rect = cc->Inputs().Tag(kNormRectTag).Get(); // Only use the rect if it is valid. 
if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 && rect.y_center() >= 0.0) { @@ -433,9 +464,9 @@ void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc, rotation = rect.rotation(); } } else { - if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) { - crop_width = cc->Inputs().Tag("WIDTH").Get(); - crop_height = cc->Inputs().Tag("HEIGHT").Get(); + if (cc->Inputs().HasTag(kWidthTag) && cc->Inputs().HasTag(kHeightTag)) { + crop_width = cc->Inputs().Tag(kWidthTag).Get(); + crop_height = cc->Inputs().Tag(kHeightTag).Get(); } else if (options_.has_width() && options_.has_height()) { crop_width = options_.width(); crop_height = options_.height(); diff --git a/mediapipe/calculators/image/image_properties_calculator.cc b/mediapipe/calculators/image/image_properties_calculator.cc index 70c49de61..ea6c06c43 100644 --- a/mediapipe/calculators/image/image_properties_calculator.cc +++ b/mediapipe/calculators/image/image_properties_calculator.cc @@ -15,9 +15,9 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gpu_buffer.h" -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -44,11 +44,11 @@ class ImagePropertiesCalculator : public CalculatorBase { if (cc->Inputs().HasTag("IMAGE")) { cc->Inputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag("IMAGE_GPU")) { cc->Inputs().Tag("IMAGE_GPU").Set<::mediapipe::GpuBuffer>(); } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("SIZE")) { cc->Outputs().Tag("SIZE").Set>(); @@ -71,7 +71,7 @@ class ImagePropertiesCalculator : public CalculatorBase { width = image.Width(); height = image.Height(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag("IMAGE_GPU") && !cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) { const auto& image = @@ -79,7 +79,7 @@ class ImagePropertiesCalculator : public CalculatorBase { width = image.width(); height = image.height(); } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU cc->Outputs().Tag("SIZE").AddPacket( MakePacket>(width, height) diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index c2d894547..5eb34c3c0 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -22,12 +22,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/scale_mode.pb.h" -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU #if defined(__ANDROID__) // The size of Java arrays is dynamic, which makes it difficult to @@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2]; namespace mediapipe { -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU namespace { int 
RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { @@ -170,12 +170,12 @@ class ImageTransformationCalculator : public CalculatorBase { mediapipe::ScaleMode_Mode scale_mode_; bool use_gpu_ = false; -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) GlCalculatorHelper helper_; std::unique_ptr rgb_renderer_; std::unique_ptr yuv_renderer_; std::unique_ptr ext_rgb_renderer_; -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(ImageTransformationCalculator); @@ -185,18 +185,22 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU")); + bool use_gpu = false; + if (cc->Inputs().HasTag("IMAGE")) { RET_CHECK(cc->Outputs().HasTag("IMAGE")); cc->Inputs().Tag("IMAGE").Set(); cc->Outputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag("IMAGE_GPU")) { RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU")); cc->Inputs().Tag("IMAGE_GPU").Set(); cc->Outputs().Tag("IMAGE_GPU").Set(); + use_gpu |= true; } -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU + if (cc->Inputs().HasTag("ROTATION_DEGREES")) { cc->Inputs().Tag("ROTATION_DEGREES").Set(); } @@ -212,9 +216,11 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); cc->Outputs().Tag("LETTERBOX_PADDING").Set>(); } -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX - MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ || iOS + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) + MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -250,12 +256,12 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE); if (use_gpu_) { -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) // Let the helper access the GL context information. 
MP_RETURN_IF_ERROR(helper_.Open(cc)); #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif // __ANDROID__ || iOS + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } return ::mediapipe::OkStatus(); @@ -264,10 +270,10 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::Process( CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) return helper_.RunInGlContext( [this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); }); -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU } else { return RenderCpu(cc); } @@ -277,7 +283,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::Close( CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) QuadRenderer* rgb_renderer = rgb_renderer_.release(); QuadRenderer* yuv_renderer = yuv_renderer_.release(); QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release(); @@ -295,8 +301,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); delete yuv_renderer; } }); -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU } + return ::mediapipe::OkStatus(); } @@ -371,7 +378,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::RenderGpu( CalculatorContext* cc) { -#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX +#if !defined(MEDIAPIPE_DISABLE_GPU) int input_width = cc->Inputs().Tag("IMAGE_GPU").Get().width(); int input_height = cc->Inputs().Tag("IMAGE_GPU").Get().height(); @@ -408,7 +415,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); #endif // iOS { src1 = helper_.CreateSourceTexture(input); -#if defined(__ANDROID__) +#if defined(TEXTURE_EXTERNAL_OES) if (src1.target() == GL_TEXTURE_EXTERNAL_OES) { if (!ext_rgb_renderer_) { ext_rgb_renderer_ = absl::make_unique(); @@ -417,7 +424,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); } renderer = ext_rgb_renderer_.get(); } else // NOLINT(readability/braces) -#endif // __ANDROID__ +#endif // TEXTURE_EXTERNAL_OES { if (!rgb_renderer_) { rgb_renderer_ = absl::make_unique(); @@ -460,7 +467,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); auto output = dst.GetFrame(); cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); -#endif // __ANDROID__ || iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc index b23eda481..fff26b704 100644 --- a/mediapipe/calculators/image/recolor_calculator.cc +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -21,12 +21,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/util/color.pb.h" -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -95,10 +94,10 @@ class RecolorCalculator : public CalculatorBase { mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; bool use_gpu_ = false; 
-#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(RecolorCalculator); @@ -107,36 +106,43 @@ REGISTER_CALCULATOR(RecolorCalculator); RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + bool use_gpu = false; + +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag("IMAGE_GPU")) { cc->Inputs().Tag("IMAGE_GPU").Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag("IMAGE")) { cc->Inputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag("MASK_GPU")) { cc->Inputs().Tag("MASK_GPU").Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag("MASK")) { cc->Inputs().Tag("MASK").Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Outputs().HasTag("IMAGE_GPU")) { cc->Outputs().Tag("IMAGE_GPU").Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("IMAGE")) { cc->Outputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ or iOS + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -146,9 +152,9 @@ REGISTER_CALCULATOR(RecolorCalculator); if (cc->Inputs().HasTag("IMAGE_GPU")) { use_gpu_ = true; -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadOptions(cc)); @@ -158,7 +164,7 @@ REGISTER_CALCULATOR(RecolorCalculator); ::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { if (!initialized_) { @@ -168,7 +174,7 @@ REGISTER_CALCULATOR(RecolorCalculator); MP_RETURN_IF_ERROR(RenderGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } @@ -176,12 +182,12 @@ REGISTER_CALCULATOR(RecolorCalculator); } ::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -194,7 +200,7 @@ REGISTER_CALCULATOR(RecolorCalculator); if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) // Get inputs and setup output. 
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value(); @@ -233,13 +239,13 @@ REGISTER_CALCULATOR(RecolorCalculator); img_tex.Release(); mask_tex.Release(); dst_tex.Release(); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } void RecolorCalculator::GlRender() { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -287,7 +293,7 @@ void RecolorCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } ::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { @@ -305,7 +311,7 @@ void RecolorCalculator::GlRender() { } ::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -374,7 +380,7 @@ void RecolorCalculator::GlRender() { glUniform1i(glGetUniformLocation(program_, "mask"), 2); glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1], color_[2]); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index f3f6cedaa..31de1e21a 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -25,12 +25,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -107,16 +106,18 @@ class SetAlphaCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(SetAlphaCalculator); ::mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); + bool use_gpu = false; + if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { return ::mediapipe::InternalError("Cannot have multiple input images."); @@ -127,38 +128,43 @@ REGISTER_CALCULATOR(SetAlphaCalculator); } // Input image to add/edit alpha channel. 
-#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input alpha image mask (optional) -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag(kInputAlphaTagGpu)) { cc->Inputs().Tag(kInputAlphaTagGpu).Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputAlphaTag)) { cc->Inputs().Tag(kInputAlphaTag).Set(); } // RGBA output image. -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ or iOS + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -170,11 +176,11 @@ REGISTER_CALCULATOR(SetAlphaCalculator); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif // __ANDROID__ or iOS + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } // Get global value from options (-1 if not set). 
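The SetAlphaCalculator hunks above, like the RecolorCalculator changes earlier and the TFLite calculators later in this patch, all move to the same contract pattern: GPU streams are declared only when MEDIAPIPE_DISABLE_GPU is not defined, a local use_gpu flag records whether any GPU stream is actually wired up, and GlCalculatorHelper::UpdateContract is requested only in that case. A condensed sketch of the pattern follows; the calculator name, tag strings, and the ImageFrame/GpuBuffer packet types are illustrative here, not taken from this patch.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#endif  // !MEDIAPIPE_DISABLE_GPU

namespace mediapipe {

// Hypothetical calculator, shown only to illustrate the contract pattern.
class GpuGatedSketchCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    bool use_gpu = false;
    if (cc->Inputs().HasTag("IMAGE")) {
      cc->Inputs().Tag("IMAGE").Set<ImageFrame>();  // CPU path, always compiled.
    }
#if !defined(MEDIAPIPE_DISABLE_GPU)
    if (cc->Inputs().HasTag("IMAGE_GPU")) {
      cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
      use_gpu |= true;  // A GPU stream is actually wired up.
    }
#endif  // !MEDIAPIPE_DISABLE_GPU
    if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
      // Request GL resources only when a GPU stream is present.
      MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif  // !MEDIAPIPE_DISABLE_GPU
    }
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return ::mediapipe::OkStatus();  // Real calculators do their work here.
  }
};

}  // namespace mediapipe

Compared with the old __ANDROID__/iOS checks, the GPU code now follows the MEDIAPIPE_DISABLE_GPU build switch rather than the platform, and UpdateContract is only called when a GPU stream is actually used, presumably so that CPU-only graphs do not request GL resources they never touch.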
@@ -187,17 +193,17 @@ REGISTER_CALCULATOR(SetAlphaCalculator); RET_CHECK_FAIL() << "Must use either image mask or options alpha value."; if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #endif - } + } // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } ::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { if (!gpu_initialized_) { @@ -207,7 +213,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); MP_RETURN_IF_ERROR(RenderGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } @@ -216,12 +222,12 @@ REGISTER_CALCULATOR(SetAlphaCalculator); } ::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -295,7 +301,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) // Setup source texture. const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); @@ -348,13 +354,13 @@ REGISTER_CALCULATOR(SetAlphaCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } void SetAlphaCalculator::GlRender(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -403,11 +409,11 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } ::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -460,7 +466,7 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) { glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2); glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 45d6cf965..4231b899e 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -255,6 +255,7 @@ mediapipe_cc_proto_library( cc_deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/util:audio_decoder_cc_proto", ], visibility = ["//visibility:public"], deps = [":unpack_media_sequence_calculator_proto"], @@ -653,6 +654,7 @@ cc_library( 
"//mediapipe/framework/formats:location", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/util:audio_decoder_cc_proto", "//mediapipe/util/sequence:media_sequence", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", @@ -769,7 +771,6 @@ cc_test( "//mediapipe/framework/formats:location", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", - "//mediapipe/framework/port:status", "//mediapipe/util/sequence:media_sequence", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -971,6 +972,7 @@ cc_test( "//mediapipe/framework/formats:location", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:rectangle", + "//mediapipe/util:audio_decoder_cc_proto", "//mediapipe/util/sequence:media_sequence", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 7780d7850..594e182ec 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -285,6 +285,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { } ::mediapipe::Status Process(CalculatorContext* cc) override { + int image_height = -1; + int image_width = -1; + // Because the tag order may vary, we need to loop through tags to get + // image information before processing other tag types. for (const auto& tag : cc->Inputs().GetTags()) { if (!cc->Inputs().Tag(tag).IsEmpty()) { features_present_[tag] = true; @@ -306,14 +310,21 @@ class PackMediaSequenceCalculator : public CalculatorBase { return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "No encoded image"; } + image_height = image.height(); + image_width = image.width(); mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(), sequence_.get()); mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get()); } + } + for (const auto& tag : cc->Inputs().GetTags()) { + if (!cc->Inputs().Tag(tag).IsEmpty()) { + features_present_[tag] = true; + } if (absl::StartsWith(tag, kKeypointsTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; - if (tag != kImageTag) { + if (tag != kKeypointsTag) { int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1; if (tag[tag_length] == '_') { key = tag.substr(tag_length + 1); @@ -363,11 +374,20 @@ class PackMediaSequenceCalculator : public CalculatorBase { LocationData::BOUNDING_BOX || detection.location_data().format() == LocationData::RELATIVE_BOUNDING_BOX) { - int height = mpms::GetImageHeight(*sequence_); - int width = mpms::GetImageWidth(*sequence_); + if (mpms::HasImageHeight(*sequence_) && + mpms::HasImageWidth(*sequence_)) { + image_height = mpms::GetImageHeight(*sequence_); + image_width = mpms::GetImageWidth(*sequence_); + } + if (image_height == -1 || image_width == -1) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Images must be provided with bounding boxes or the " + "image " + << "height and width must already be in the example."; + } Location relative_bbox = Location::CreateRelativeBBoxLocation( Location(detection.location_data()) - .ConvertToRelativeBBox(width, height)); + .ConvertToRelativeBBox(image_width, image_height)); predicted_locations.push_back(relative_bbox); if (detection.label_size() > 0) { predicted_class_strings.push_back(detection.label(0)); diff --git 
a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 19b302e13..df43a921f 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -357,6 +357,148 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) { } } +TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) { + SetUpCalculator({"BBOX_PREDICTED:detections"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + int height = 480; + int width = 640; + int num_vectors = 2; + for (int i = 0; i < num_vectors; ++i) { + auto detections = ::absl::make_unique<::std::vector>(); + Detection detection; + detection.add_label("absolute bbox"); + detection.add_label_id(0); + detection.add_score(0.5); + Location::CreateBBoxLocation(0, height / 2, width / 2, height / 2) + .ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + + detection = Detection(); + detection.add_label("relative bbox"); + detection.add_label_id(1); + detection.add_score(0.75); + Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) + .ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + + // The mask detection should be ignored in the output. + detection = Detection(); + detection.add_label("mask"); + detection.add_score(1.0); + cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0)); + Location::CreateCvMaskLocation(image).ConvertToProto( + detection.mutable_location_data()); + detections->push_back(detection); + + runner_->MutableInputs() + ->Tag("BBOX_PREDICTED") + .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + auto status = runner_->Run(); + EXPECT_EQ(::mediapipe::StatusCode::kInvalidArgument, status.code()); +} + +TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { + SetUpCalculator({"BBOX_PREDICTED:detections", "IMAGE:images"}, {}, false, + true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + int height = 480; + int width = 640; + int num_vectors = 2; + for (int i = 0; i < num_vectors; ++i) { + auto detections = ::absl::make_unique<::std::vector>(); + Detection detection; + detection.add_label("absolute bbox"); + detection.add_label_id(0); + detection.add_score(0.5); + Location::CreateBBoxLocation(0, height / 2, width / 2, height / 2) + .ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + + detection = Detection(); + detection.add_label("relative bbox"); + detection.add_label_id(1); + detection.add_score(0.75); + Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) + .ConvertToProto(detection.mutable_location_data()); + detections->push_back(detection); + + // The mask detection should be ignored in the output. 
+ detection = Detection(); + detection.add_label("mask"); + detection.add_score(1.0); + cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0)); + Location::CreateCvMaskLocation(image).ConvertToProto( + detection.mutable_location_data()); + detections->push_back(detection); + + runner_->MutableInputs() + ->Tag("BBOX_PREDICTED") + .packets.push_back(Adopt(detections.release()).At(Timestamp(i))); + } + cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + std::string test_image_string(bytes.begin(), bytes.end()); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(test_image_string); + encoded_image.set_width(width); + encoded_image.set_height(height); + + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs()->Tag("IMAGE").packets.push_back( + Adopt(image_ptr.release()).At(Timestamp(i))); + } + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(height, mpms::GetImageHeight(output_sequence)); + ASSERT_EQ(width, mpms::GetImageWidth(output_sequence)); + ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxSize(output_sequence)); + ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxTimestampSize(output_sequence)); + ASSERT_EQ(0, mpms::GetClassSegmentationEncodedSize(output_sequence)); + ASSERT_EQ(0, mpms::GetClassSegmentationTimestampSize(output_sequence)); + for (int i = 0; i < num_vectors; ++i) { + ASSERT_EQ(i, mpms::GetPredictedBBoxTimestampAt(output_sequence, i)); + auto bboxes = mpms::GetPredictedBBoxAt(output_sequence, i); + ASSERT_EQ(2, bboxes.size()); + for (int j = 0; j < bboxes.size(); ++j) { + auto rect = bboxes[j].GetRelativeBBox(); + ASSERT_NEAR(0, rect.xmin(), 0.001); + ASSERT_NEAR(0.5, rect.ymin(), 0.001); + ASSERT_NEAR(0.5, rect.xmax(), 0.001); + ASSERT_NEAR(1.0, rect.ymax(), 0.001); + } + auto class_strings = + mpms::GetPredictedBBoxLabelStringAt(output_sequence, i); + ASSERT_EQ("absolute bbox", class_strings[0]); + ASSERT_EQ("relative bbox", class_strings[1]); + auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i); + ASSERT_EQ(0, class_indices[0]); + ASSERT_EQ(1, class_indices[1]); + } +} + TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { SetUpCalculator({"KEYPOINTS_TEST:keypoints"}, {}, false, true); auto input_sequence = ::absl::make_unique(); diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 51493d7b6..a92b48d30 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/util/audio_decoder.pb.h" #include "mediapipe/util/sequence/media_sequence.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" @@ -37,6 +38,7 @@ const char kDatasetRootDirTag[] = "DATASET_ROOT"; const char 
kDataPath[] = "DATA_PATH"; const char kPacketResamplerOptions[] = "RESAMPLER_OPTIONS"; const char kImagesFrameRateTag[] = "IMAGE_FRAME_RATE"; +const char kAudioDecoderOptions[] = "AUDIO_DECODER_OPTIONS"; namespace tf = ::tensorflow; namespace mpms = ::mediapipe::mediasequence; @@ -126,6 +128,11 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { if (cc->OutputSidePackets().HasTag(kDataPath)) { cc->OutputSidePackets().Tag(kDataPath).Set(); } + if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions)) { + cc->OutputSidePackets() + .Tag(kAudioDecoderOptions) + .Set(); + } if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) { cc->OutputSidePackets().Tag(kImagesFrameRateTag).Set(); } @@ -136,10 +143,11 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { } if ((options.has_padding_before_label() || options.has_padding_after_label()) && - !(cc->OutputSidePackets().HasTag(kPacketResamplerOptions))) { + !(cc->OutputSidePackets().HasTag(kAudioDecoderOptions) || + cc->OutputSidePackets().HasTag(kPacketResamplerOptions))) { return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "If specifying padding, must output " - << kPacketResamplerOptions; + << "If specifying padding, must output " << kPacketResamplerOptions + << "or" << kAudioDecoderOptions; } // Optional streams. @@ -260,7 +268,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { // Set the start and end of the clip in the appropriate options protos. double start_time = 0; double end_time = 0; - if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { + if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions) || + cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { if (mpms::HasClipStartTimestamp(sequence)) { start_time = Timestamp(mpms::GetClipStartTimestamp(sequence)).Seconds() - @@ -271,6 +280,27 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { options.padding_after_label(); } } + if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions)) { + auto audio_decoder_options = absl::make_unique( + options.base_audio_decoder_options()); + if (mpms::HasClipStartTimestamp(sequence)) { + if (options.force_decoding_from_start_of_media()) { + audio_decoder_options->set_start_time(0); + } else { + audio_decoder_options->set_start_time( + start_time - options.extra_padding_from_media_decoder()); + } + } + if (mpms::HasClipEndTimestamp(sequence)) { + audio_decoder_options->set_end_time( + end_time + options.extra_padding_from_media_decoder()); + } + LOG(INFO) << "Created AudioDecoderOptions:\n" + << audio_decoder_options->DebugString(); + cc->OutputSidePackets() + .Tag(kAudioDecoderOptions) + .Set(Adopt(audio_decoder_options.release())); + } if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { auto resampler_options = absl::make_unique(); *(resampler_options->MutableExtension( diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto index e6e839645..51cc870c7 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto @@ -18,6 +18,7 @@ package mediapipe; import "mediapipe/calculators/core/packet_resampler_calculator.proto"; import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/audio_decoder.proto"; message UnpackMediaSequenceCalculatorOptions { extend mediapipe.CalculatorOptions { @@ -49,4 +50,10 @@ message UnpackMediaSequenceCalculatorOptions 
{ // parameters for the MediaDecoderCalculator. End time parameters are still // respected. optional bool force_decoding_from_start_of_media = 7; + + // Stores the audio decoder settings for the graph. (e.g. which audio + // stream to pull from the video.) The sequence's metadata overrides + // the clip start and end times and outputs these for the + // AudioDecoderCalculator to consume. + optional AudioDecoderOptions base_audio_decoder_options = 9; } diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index 36958ef8f..185e2e186 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -23,6 +23,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/rectangle.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/audio_decoder.pb.h" #include "mediapipe/util/sequence/media_sequence.h" #include "tensorflow/core/example/example.pb.h" @@ -459,6 +460,62 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) { data_path_); } +TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) { + CalculatorOptions options; + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_padding_before_label(1); + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_padding_after_label(2); + SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, + &options); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + MP_ASSERT_OK(runner_->Run()); + + MP_EXPECT_OK(runner_->OutputSidePackets() + .Tag("AUDIO_DECODER_OPTIONS") + .ValidateAsType()); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("AUDIO_DECODER_OPTIONS") + .Get() + .start_time(), + 2.0, 1e-5); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("AUDIO_DECODER_OPTIONS") + .Get() + .end_time(), + 7.0, 1e-5); +} + +TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) { + CalculatorOptions options; + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_padding_before_label(1); + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_padding_after_label(2); + options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) + ->set_force_decoding_from_start_of_media(true); + SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, + &options); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(sequence_.release()); + MP_ASSERT_OK(runner_->Run()); + + MP_EXPECT_OK(runner_->OutputSidePackets() + .Tag("AUDIO_DECODER_OPTIONS") + .ValidateAsType()); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("AUDIO_DECODER_OPTIONS") + .Get() + .start_time(), + 0.0, 1e-5); + EXPECT_NEAR(runner_->OutputSidePackets() + .Tag("AUDIO_DECODER_OPTIONS") + .Get() + .end_time(), + 7.0, 1e-5); +} + TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { // TODO: Suport proto3 proto.Any in CalculatorOptions. // TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS". 
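The two GetAudioDecoderOptions tests above pin down the timing arithmetic added to UnpackMediaSequenceCalculator: the labeled clip is widened by padding_before_label and padding_after_label, extra_padding_from_media_decoder widens the decoder window further on both ends, and force_decoding_from_start_of_media clamps the start to zero while leaving the end untouched. Below is a standalone sketch of that arithmetic; it assumes the test fixture's sequence carries a clip from 3 s to 5 s (inferred from the expected 2.0/7.0 values, since the fixture setup is not shown in this excerpt) and zero extra decoder padding.

// Standalone sketch of the clip-to-decoder timing arithmetic exercised by
// GetAudioDecoderOptions and GetAudioDecoderOptionsOverride above.
#include <cstdio>

struct DecoderWindow {
  double start_time;
  double end_time;
};

DecoderWindow ComputeDecoderWindow(double clip_start, double clip_end,
                                   double padding_before, double padding_after,
                                   double extra_padding,
                                   bool force_decoding_from_start) {
  // Mirrors the calculator: pad the labeled clip first...
  const double start = clip_start - padding_before;
  const double end = clip_end + padding_after;
  // ...then widen (or zero) the window handed to the audio decoder.
  return {force_decoding_from_start ? 0.0 : start - extra_padding,
          end + extra_padding};
}

int main() {
  // padding_before_label = 1, padding_after_label = 2, no extra padding.
  DecoderWindow normal = ComputeDecoderWindow(3.0, 5.0, 1.0, 2.0, 0.0, false);
  DecoderWindow forced = ComputeDecoderWindow(3.0, 5.0, 1.0, 2.0, 0.0, true);
  std::printf("normal: start=%.1f end=%.1f\n", normal.start_time, normal.end_time);
  std::printf("forced: start=%.1f end=%.1f\n", forced.start_time, forced.end_time);
  return 0;
}

Running it prints start=2.0 end=7.0 for the normal case and start=0.0 end=7.0 when decoding is forced from the start of media, matching the EXPECT_NEAR checks in the tests above.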
diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 93f08edc5..a0b1fc0b6 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -195,6 +195,12 @@ cc_test( ], ) +cc_library( + name = "util", + hdrs = ["util.h"], + alwayslink = 1, +) + cc_library( name = "tflite_inference_calculator", srcs = ["tflite_inference_calculator.cc"], @@ -214,6 +220,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":util", ":tflite_inference_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/util:resource_util", @@ -222,20 +229,25 @@ cc_library( "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/port:ret_check", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/objc:mediapipe_framework_ios", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + ], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", ], - "//mediapipe:ios": [ - "//mediapipe/gpu:MPPMetalHelper", - "//mediapipe/objc:mediapipe_framework_ios", - "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", - ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -259,33 +271,33 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":util", ":tflite_converter_calculator_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", - "//mediapipe/framework/tool:status_util", - "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ] + select({ - "//mediapipe:android": [ - "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gpu_buffer", - "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", - ], + "//mediapipe/gpu:disable_gpu": [], "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", ], - "//conditions:default": [], + "//conditions:default": [ + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gl_calculator_helper", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + ], }), alwayslink = 1, ) @@ -295,6 +307,7 @@ 
cc_library( srcs = ["tflite_tensors_to_segmentation_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":util", ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -308,7 +321,9 @@ cc_library( "//mediapipe/util:resource_util", "@org_tensorflow//tensorflow/lite:framework", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//mediapipe:ios": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gpu_buffer", @@ -319,7 +334,6 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -346,8 +360,23 @@ cc_test( cc_library( name = "tflite_tensors_to_detections_calculator", srcs = ["tflite_tensors_to_detections_calculator.cc"], + copts = select({ + "//mediapipe:ios": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + linkopts = select({ + "//mediapipe:ios": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ + ":util", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -359,14 +388,21 @@ cc_library( "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/objc:mediapipe_framework_ios", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + ], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", ], - "//conditions:default": [], }), alwayslink = 1, ) diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index 37952008b..598ae4965 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -16,23 +16,23 @@ #include #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" +#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" #include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gpu_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // __ANDROID__ +#endif // 
!MEDIAPIPE_DISABLE_GPU #if defined(__APPLE__) && !TARGET_OS_OSX // iOS #import @@ -40,11 +40,12 @@ #import #import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" #include "mediapipe/gpu/gpu_buffer.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h" #endif // iOS -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) typedef ::tflite::gpu::gl::GlBuffer GpuTensor; #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS typedef id GpuTensor; @@ -66,26 +67,27 @@ typedef Eigen::Matrix namespace mediapipe { -#if defined(__ANDROID__) -using ::tflite::gpu::gl::GlBuffer; +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlShader; struct GPUData { int elements = 1; - GlBuffer buffer; + GpuTensor buffer; GlShader shader; GlProgram program; }; #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS struct GPUData { int elements = 1; - id buffer; + GpuTensor buffer; id pipeline_state; }; #endif // Calculator for normalizing and converting an ImageFrame or Matrix -// into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer. +// into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer +// or MTLBuffer. // // This calculator is designed to be used with the TfLiteInferenceCalcualtor, // as a pre-processing step for calculator inputs. @@ -102,7 +104,7 @@ struct GPUData { // Output: // One of the following tags: // TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8. -// TENSORS_GPU - vector of GlBuffer. +// TENSORS_GPU - vector of GlBuffer or MTLBuffer. // // Example use: // node { @@ -144,7 +146,7 @@ class TfLiteConverterCalculator : public CalculatorBase { std::unique_ptr interpreter_ = nullptr; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) mediapipe::GlCalculatorHelper gpu_helper_; std::unique_ptr gpu_data_out_; #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS @@ -175,25 +177,33 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); RET_CHECK(cc->Outputs().HasTag("TENSORS") ^ cc->Outputs().HasTag("TENSORS_GPU")); + bool use_gpu = false; + if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set(); if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set(); -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - if (cc->Inputs().HasTag("IMAGE_GPU")) +#if !defined(MEDIAPIPE_DISABLE_GPU) + if (cc->Inputs().HasTag("IMAGE_GPU")) { cc->Inputs().Tag("IMAGE_GPU").Set(); -#endif + use_gpu |= true; + } +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("TENSORS")) cc->Outputs().Tag("TENSORS").Set>(); -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - if (cc->Outputs().HasTag("TENSORS_GPU")) +#if !defined(MEDIAPIPE_DISABLE_GPU) + if (cc->Outputs().HasTag("TENSORS_GPU")) { cc->Outputs().Tag("TENSORS_GPU").Set>(); -#endif + use_gpu |= true; + } +#endif // !MEDIAPIPE_DISABLE_GPU -#if defined(__ANDROID__) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS - MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif + } // Assign this calculator's default InputStreamHandler. 
cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); @@ -208,10 +218,10 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); if (cc->Inputs().HasTag("IMAGE_GPU") || cc->Outputs().HasTag("IMAGE_OUT_GPU")) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; + RET_CHECK_FAIL() << "GPU processing not enabled."; #endif } @@ -221,7 +231,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); cc->Outputs().HasTag("TENSORS_GPU")); // Cannot use quantization. use_quantized_tensors_ = false; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; @@ -238,6 +248,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); ::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { + // GpuBuffer to tflite::gpu::GlBuffer conversion. if (!initialized_) { MP_RETURN_IF_ERROR(InitGpu(cc)); initialized_ = true; @@ -253,7 +264,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); } ::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); #endif #if defined(__APPLE__) && !TARGET_OS_OSX // iOS @@ -372,7 +383,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); ::mediapipe::Status TfLiteConverterCalculator::ProcessGPU( CalculatorContext* cc) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) // GpuBuffer to tflite::gpu::GlBuffer conversion. const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); MP_RETURN_IF_ERROR( @@ -381,17 +392,11 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); auto src = gpu_helper_.CreateSourceTexture(input); glActiveTexture(GL_TEXTURE0 + 0); glBindTexture(GL_TEXTURE_2D, src.name()); - auto status = gpu_data_out_->buffer.BindToIndex(1); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + RET_CHECK_CALL(gpu_data_out_->buffer.BindToIndex(1)); const tflite::gpu::uint3 workgroups = { NumGroups(input.width(), kWorkgroupSize), NumGroups(input.height(), kWorkgroupSize), 1}; - status = gpu_data_out_->program.Dispatch(workgroups); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + RET_CHECK_CALL(gpu_data_out_->program.Dispatch(workgroups)); glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); glBindTexture(GL_TEXTURE_2D, 0); src.Release(); @@ -400,17 +405,17 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); // Copy into outputs. 
auto output_tensors = absl::make_unique>(); - output_tensors->resize(1); - { - GlBuffer& tensor = output_tensors->at(0); - using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; - auto status = CreateReadWriteShaderStorageBuffer( - gpu_data_out_->elements, &tensor); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - tflite::gpu::gl::CopyBuffer(gpu_data_out_->buffer, tensor); - } + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &output_tensors]() -> ::mediapipe::Status { + output_tensors->resize(1); + { + GpuTensor& tensor = output_tensors->at(0); + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + gpu_data_out_->elements, &tensor)); + RET_CHECK_CALL(CopyBuffer(gpu_data_out_->buffer, tensor)); + } + return ::mediapipe::OkStatus(); + })); cc->Outputs() .Tag("TENSORS_GPU") .Add(output_tensors.release(), cc->InputTimestamp()); @@ -438,66 +443,60 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); } // Copy into outputs. + // TODO Avoid this copy. auto output_tensors = absl::make_unique>(); + output_tensors->resize(1); { id device = gpu_helper_.mtlDevice; - id command_buffer = [gpu_helper_ commandBuffer]; - command_buffer.label = @"TfLiteConverterCalculatorCopy"; - id tensor = + output_tensors->at(0) = [device newBufferWithLength:gpu_data_out_->elements * sizeof(float) options:MTLResourceStorageModeShared]; - id blit_command = - [command_buffer blitCommandEncoder]; - [blit_command copyFromBuffer:gpu_data_out_->buffer - sourceOffset:0 - toBuffer:tensor - destinationOffset:0 - size:gpu_data_out_->elements * sizeof(float)]; - [blit_command endEncoding]; - [command_buffer commit]; - [command_buffer waitUntilCompleted]; - - output_tensors->push_back(tensor); + [MPPMetalUtil blitMetalBufferTo:output_tensors->at(0) + from:gpu_data_out_->buffer + blocking:true + commandBuffer:[gpu_helper_ commandBuffer]]; } cc->Outputs() .Tag("TENSORS_GPU") .Add(output_tensors.release(), cc->InputTimestamp()); #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; + RET_CHECK_FAIL() << "GPU processing is not enabled."; #endif return ::mediapipe::OkStatus(); } ::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - // Configure inputs. +#if !defined(MEDIAPIPE_DISABLE_GPU) + // Get input image sizes. const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); mediapipe::ImageFormat::Format format = mediapipe::ImageFormatForGpuBufferFormat(input.format()); gpu_data_out_ = absl::make_unique(); gpu_data_out_->elements = input.height() * input.width() * max_num_channels_; const bool include_alpha = (max_num_channels_ == 4); - if (!(format == mediapipe::ImageFormat::SRGB || + const bool single_channel = (max_num_channels_ == 1); + if (!(format == mediapipe::ImageFormat::GRAY8 || + format == mediapipe::ImageFormat::SRGB || format == mediapipe::ImageFormat::SRGBA)) RET_CHECK_FAIL() << "Unsupported GPU input format."; if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) RET_CHECK_FAIL() << "Num input channels is less than desired output."; -#endif +#endif // !MEDIAPIPE_DISABLE_GPU -#if defined(__ANDROID__) - // Device memory. 
- auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - gpu_data_out_->elements, &gpu_data_out_->buffer); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status { + // Device memory. + RET_CHECK_CALL( + ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( + gpu_data_out_->elements, &gpu_data_out_->buffer)); - // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), - // with normalization to either: [0,1] or [-1,1]. - const std::string shader_source = absl::Substitute( - R"( #version 310 es + // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), + // with normalization to either: [0,1] or [-1,1]. + const std::string shader_source = absl::Substitute( + R"( #version 310 es layout(local_size_x = $0, local_size_y = $0) in; layout(binding = 0) uniform sampler2D input_texture; layout(std430, binding = 1) buffer Output {float elements[];} output_data; @@ -505,33 +504,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); void main() { ivec2 gid = ivec2(gl_GlobalInvocationID.xy); if (gid.x >= width_height.x || gid.y >= width_height.y) return; - $5 // pixel fetch + vec4 pixel = texelFetch(input_texture, gid, 0); $3 // normalize [-1,1] int linear_index = $7 * ($4 * width_height.x + gid.x); - output_data.elements[linear_index + 0] = pixel.x; - output_data.elements[linear_index + 1] = pixel.y; - output_data.elements[linear_index + 2] = pixel.z; + output_data.elements[linear_index + 0] = pixel.x; // r channel + $5 // g & b channels $6 // alpha channel })", - /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(), - /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", - /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y", - /*$5=*/ - include_alpha ? "vec4 pixel = texelFetch(input_texture, gid, 0);" - : "vec3 pixel = texelFetch(input_texture, gid, 0).xyz;", - /*$6=*/ - include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;" : "", - /*$7=*/include_alpha ? 4 : 3); - status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, - &gpu_data_out_->shader); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - status = GlProgram::CreateWithShader(gpu_data_out_->shader, - &gpu_data_out_->program); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(), + /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", + /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y", + /*$5=*/ + single_channel + ? "" + : R"(output_data.elements[linear_index + 1] = pixel.y; + output_data.elements[linear_index + 2] = pixel.z;)", + /*$6=*/ + include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;" + : "", + /*$7=*/max_num_channels_); + RET_CHECK_CALL(GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, + &gpu_data_out_->shader)); + RET_CHECK_CALL(GlProgram::CreateWithShader(gpu_data_out_->shader, + &gpu_data_out_->program)); + return ::mediapipe::OkStatus(); + })); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS RET_CHECK(include_alpha) << "iOS GPU inference currently accepts only RGBA input."; @@ -546,8 +543,6 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); // with normalization to either: [0,1] or [-1,1]. 
const std::string shader_source = absl::Substitute( R"( - #include - #include using namespace metal; @@ -612,9 +607,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); // Get desired way to handle input channels. max_num_channels_ = options.max_num_channels(); - // Currently only alpha channel toggling is suppored. - CHECK_GE(max_num_channels_, 3); + CHECK_GE(max_num_channels_, 1); CHECK_LE(max_num_channels_, 4); + CHECK_NE(max_num_channels_, 2); #if defined(__APPLE__) && !TARGET_OS_OSX // iOS if (cc->Inputs().HasTag("IMAGE_GPU")) // Currently on iOS, tflite gpu input tensor must be 4 channels, diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.proto b/mediapipe/calculators/tflite/tflite_converter_calculator.proto index 3be32b347..2c0d8f4e1 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.proto @@ -36,8 +36,7 @@ message TfLiteConverterCalculatorOptions { optional bool flip_vertically = 2 [default = false]; // Controls how many channels of the input image get passed through to the - // tensor. Currently this only controls whether or not to ignore alpha - // channel, so it must be 3 or 4. + // tensor. Valid values are 1,3,4 only. Ignored for iOS GPU. optional int32 max_num_channels = 3 [default = 3]; // The calculator expects Matrix inputs to be in column-major order. Set diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index f9b5646a4..9bc02b48c 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" +#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" @@ -24,14 +27,15 @@ #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gpu_buffer.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU #if defined(__APPLE__) && !TARGET_OS_OSX // iOS #import @@ -39,33 +43,42 @@ #import #import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h" #endif // iOS -#if defined(__ANDROID__) +namespace { + +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) typedef ::tflite::gpu::gl::GlBuffer GpuTensor; #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS typedef id GpuTensor; #endif +// Round up n to next multiple of m. 
+size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT +} // namespace + // TfLiteInferenceCalculator File Layout: // * Header // * Core // * Aux namespace mediapipe { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +using ::tflite::gpu::gl::CopyBuffer; +using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::GlBuffer; -using ::tflite::gpu::gl::GlProgram; -using ::tflite::gpu::gl::GlShader; +#endif + +#if !defined(MEDIAPIPE_DISABLE_GPU) struct GPUData { int elements = 1; - GlBuffer buffer; -}; -#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS -struct GPUData { - int elements = 1; - id buffer; + GpuTensor buffer; + ::tflite::gpu::BHWC shape; }; #endif @@ -134,7 +147,7 @@ class TfLiteInferenceCalculator : public CalculatorBase { std::unique_ptr model_; TfLiteDelegate* delegate_ = nullptr; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) mediapipe::GlCalculatorHelper gpu_helper_; std::unique_ptr gpu_data_in_; std::vector> gpu_data_out_; @@ -142,6 +155,7 @@ class TfLiteInferenceCalculator : public CalculatorBase { MPPMetalHelper* gpu_helper_ = nullptr; std::unique_ptr gpu_data_in_; std::vector> gpu_data_out_; + TFLBufferConvert* converter_from_BPHWC4_ = nil; #endif std::string model_path_ = ""; @@ -161,19 +175,25 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); RET_CHECK(cc->Outputs().HasTag("TENSORS") ^ cc->Outputs().HasTag("TENSORS_GPU")); + bool use_gpu = false; + if (cc->Inputs().HasTag("TENSORS")) cc->Inputs().Tag("TENSORS").Set>(); -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - if (cc->Inputs().HasTag("TENSORS_GPU")) +#if !defined(MEDIAPIPE_DISABLE_GPU) + if (cc->Inputs().HasTag("TENSORS_GPU")) { cc->Inputs().Tag("TENSORS_GPU").Set>(); -#endif + use_gpu |= true; + } +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("TENSORS")) cc->Outputs().Tag("TENSORS").Set>(); -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - if (cc->Outputs().HasTag("TENSORS_GPU")) +#if !defined(MEDIAPIPE_DISABLE_GPU) + if (cc->Outputs().HasTag("TENSORS_GPU")) { cc->Outputs().Tag("TENSORS_GPU").Set>(); -#endif + use_gpu |= true; + } +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { cc->InputSidePackets() @@ -181,11 +201,17 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); .Set(); } -#if defined(__ANDROID__) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + const auto& options = + cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>(); + use_gpu |= options.use_gpu(); + + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS - MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif + } // Assign this calculator's default InputStreamHandler. cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); @@ -199,35 +225,41 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); MP_RETURN_IF_ERROR(LoadOptions(cc)); if (cc->Inputs().HasTag("TENSORS_GPU")) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_input_ = true; gpu_inference_ = true; // Inference must be on GPU also. 
#else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif + RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU")) + << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } if (cc->Outputs().HasTag("TENSORS_GPU")) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_output_ = true; RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU")) << "GPU output must also have GPU Input."; #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif + RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU")) + << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadModel(cc)); if (gpu_inference_) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); #endif - +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); })); +#else MP_RETURN_IF_ERROR(LoadDelegate(cc)); +#endif } return ::mediapipe::OkStatus(); @@ -237,35 +269,27 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // 1. Receive pre-processed tensor inputs. if (gpu_input_) { // Read GPU input into SSBO. -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) const auto& input_tensors = cc->Inputs().Tag("TENSORS_GPU").Get>(); RET_CHECK_EQ(input_tensors.size(), 1); MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( [this, &input_tensors]() -> ::mediapipe::Status { // Explicit copy input. - tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->buffer); + RET_CHECK_CALL(CopyBuffer(input_tensors[0], gpu_data_in_->buffer)); return ::mediapipe::OkStatus(); })); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS const auto& input_tensors = cc->Inputs().Tag("TENSORS_GPU").Get>(); RET_CHECK_EQ(input_tensors.size(), 1); - id command_buffer = [gpu_helper_ commandBuffer]; - command_buffer.label = @"TfLiteInferenceCalculatorInput"; - id blit_command = - [command_buffer blitCommandEncoder]; // Explicit copy input. - [blit_command copyFromBuffer:input_tensors[0] - sourceOffset:0 - toBuffer:gpu_data_in_->buffer - destinationOffset:0 - size:gpu_data_in_->elements * sizeof(float)]; - [blit_command endEncoding]; - [command_buffer commit]; - [command_buffer waitUntilCompleted]; + [MPPMetalUtil blitMetalBufferTo:gpu_data_in_->buffer + from:input_tensors[0] + blocking:true + commandBuffer:[gpu_helper_ commandBuffer]]; #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; + RET_CHECK_FAIL() << "GPU processing not enabled."; #endif } else { // Read CPU input into tensors. @@ -278,18 +302,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); if (use_quantized_tensors_) { const uint8* input_tensor_buffer = input_tensor->data.uint8; uint8* local_tensor_buffer = interpreter_->typed_input_tensor(i); - memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); + std::memcpy(local_tensor_buffer, input_tensor_buffer, + input_tensor->bytes); } else { const float* input_tensor_buffer = input_tensor->data.f; float* local_tensor_buffer = interpreter_->typed_input_tensor(i); - memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); + std::memcpy(local_tensor_buffer, input_tensor_buffer, + input_tensor->bytes); } } } // 2. Run inference. 
if (gpu_inference_) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status { RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); @@ -304,52 +330,51 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // 3. Output processed tensors. if (gpu_output_) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) // Output result tensors (GPU). auto output_tensors = absl::make_unique>(); - output_tensors->resize(gpu_data_out_.size()); - for (int i = 0; i < gpu_data_out_.size(); ++i) { - GlBuffer& tensor = output_tensors->at(i); - using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; - auto status = CreateReadWriteShaderStorageBuffer( - gpu_data_out_[i]->elements, &tensor); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->buffer, tensor); - } + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &output_tensors]() -> ::mediapipe::Status { + output_tensors->resize(gpu_data_out_.size()); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + GpuTensor& tensor = output_tensors->at(i); + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + gpu_data_out_[i]->elements, &tensor)); + RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor)); + } + return ::mediapipe::OkStatus(); + })); cc->Outputs() .Tag("TENSORS_GPU") .Add(output_tensors.release(), cc->InputTimestamp()); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS // Output result tensors (GPU). auto output_tensors = absl::make_unique>(); + output_tensors->resize(gpu_data_out_.size()); id device = gpu_helper_.mtlDevice; id command_buffer = [gpu_helper_ commandBuffer]; - command_buffer.label = @"TfLiteInferenceCalculatorOutput"; + command_buffer.label = @"TfLiteInferenceBPHWC4Convert"; + id convert_command = + [command_buffer computeCommandEncoder]; for (int i = 0; i < gpu_data_out_.size(); ++i) { - id tensor = + output_tensors->at(i) = [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) options:MTLResourceStorageModeShared]; - id blit_command = - [command_buffer blitCommandEncoder]; - // Explicit copy input. - [blit_command copyFromBuffer:gpu_data_out_[i]->buffer - sourceOffset:0 - toBuffer:tensor - destinationOffset:0 - size:gpu_data_out_[i]->elements * sizeof(float)]; - [blit_command endEncoding]; - [command_buffer commit]; - [command_buffer waitUntilCompleted]; - output_tensors->push_back(tensor); + // Reshape tensor. + [converter_from_BPHWC4_ convertWithEncoder:convert_command + shape:gpu_data_out_[i]->shape + sourceBuffer:gpu_data_out_[i]->buffer + convertedBuffer:output_tensors->at(i)]; } + [convert_command endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; cc->Outputs() .Tag("TENSORS_GPU") .Add(output_tensors.release(), cc->InputTimestamp()); #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } else { // Output result tensors (CPU). 
const auto& tensor_indexes = interpreter_->outputs(); @@ -367,7 +392,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); ::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { if (delegate_) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { TfLiteGpuDelegateDelete(delegate_); gpu_data_in_.reset(); @@ -446,7 +471,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); ::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( CalculatorContext* cc) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) // Configure and create the delegate. TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); options.compile_options.precision_loss_allowed = 1; @@ -466,15 +491,12 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); for (int d = 0; d < tensor->dims->size; ++d) { gpu_data_in_->elements *= tensor->dims->data[d]; } - // Input to model can be either RGB/RGBA only. - RET_CHECK_GE(tensor->dims->data[3], 3); - RET_CHECK_LE(tensor->dims->data[3], 4); + CHECK_GE(tensor->dims->data[3], 1); + CHECK_LE(tensor->dims->data[3], 4); + CHECK_NE(tensor->dims->data[3], 2); // Create and bind input buffer. - auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - gpu_data_in_->elements, &gpu_data_in_->buffer); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + RET_CHECK_CALL(::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( + gpu_data_in_->elements, &gpu_data_in_->buffer)); RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( delegate_, gpu_data_in_->buffer.id(), interpreter_->inputs()[0]), // First tensor only @@ -496,12 +518,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Create and bind output buffers. interpreter_->SetAllowBufferHandleOutput(true); for (int i = 0; i < gpu_data_out_.size(); ++i) { - using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; - auto status = CreateReadWriteShaderStorageBuffer( - gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer)); RET_CHECK_EQ( TfLiteGpuDelegateBindBufferToTensor( delegate_, gpu_data_out_[i]->buffer.id(), output_indices[i]), @@ -511,14 +529,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Must call this last. RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); -#endif // __ANDROID__ +#endif // OpenGL #if defined(__APPLE__) && !TARGET_OS_OSX // iOS // Configure and create the delegate. GpuDelegateOptions options; options.allow_precision_loss = false; // Must match converter, F=float/T=half - options.wait_type = GpuDelegateOptions::WaitType::kActive; + options.wait_type = GpuDelegateOptions::WaitType::kPassive; if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options); + id device = gpu_helper_.mtlDevice; if (gpu_input_) { // Get input image sizes. @@ -539,11 +558,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); LOG(WARNING) << "Please ensure input GPU tensor is 4 channels."; } // Create and bind input buffer. - id device = gpu_helper_.mtlDevice; gpu_data_in_->buffer = [device newBufferWithLength:gpu_data_in_->elements * sizeof(float) options:MTLResourceStorageModeShared]; - // Must call this before TFLGpuDelegateBindMetalBufferToTensor. 
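The input-tensor check in LoadDelegate above is relaxed from RGB/RGBA only to any channel count in {1, 3, 4}, with two channels still rejected. A tiny standalone sketch of that predicate (the function name is illustrative, not the calculator's):

#include <cassert>

// Mirrors the CHECK_GE / CHECK_LE / CHECK_NE trio on tensor->dims->data[3].
bool IsSupportedInputChannelCount(int channels) {
  return channels >= 1 && channels <= 4 && channels != 2;
}

int main() {
  assert(IsSupportedInputChannelCount(1));   // e.g. single-channel inputs
  assert(IsSupportedInputChannelCount(3));   // RGB
  assert(IsSupportedInputChannelCount(4));   // RGBA
  assert(!IsSupportedInputChannelCount(2));  // two-channel input stays unsupported
  return 0;
}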
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( delegate_, @@ -561,12 +578,33 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); gpu_data_out_[i]->elements = 1; // TODO handle *2 properly on some dialated models for (int d = 0; d < tensor->dims->size; ++d) { - gpu_data_out_[i]->elements *= tensor->dims->data[d]; + // Pad each dim for BHWC4 conversion inside delegate. + gpu_data_out_[i]->elements *= RoundUp(tensor->dims->data[d], 4); + } + // Save dimensions for reshaping back later. + gpu_data_out_[i]->shape.b = tensor->dims->data[0]; + switch (tensor->dims->size) { + case 2: + gpu_data_out_[i]->shape.h = 1; + gpu_data_out_[i]->shape.w = 1; + gpu_data_out_[i]->shape.c = tensor->dims->data[1]; + break; + case 3: + gpu_data_out_[i]->shape.h = 1; + gpu_data_out_[i]->shape.w = tensor->dims->data[1]; + gpu_data_out_[i]->shape.c = tensor->dims->data[2]; + break; + case 4: + gpu_data_out_[i]->shape.h = tensor->dims->data[1]; + gpu_data_out_[i]->shape.w = tensor->dims->data[2]; + gpu_data_out_[i]->shape.c = tensor->dims->data[3]; + break; + default: + return mediapipe::InternalError("Unsupported tensor shape."); } } // Create and bind output buffers. interpreter_->SetAllowBufferHandleOutput(true); - id device = gpu_helper_.mtlDevice; for (int i = 0; i < gpu_data_out_.size(); ++i) { gpu_data_out_[i]->buffer = [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) @@ -575,6 +613,14 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); delegate_, output_indices[i], gpu_data_out_[i]->buffer), true); } + // Create converter for GPU output. + converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:false + convertToPBHWC4:false]; + if (converter_from_BPHWC4_ == nil) { + return mediapipe::InternalError( + "Error initializating output buffer converter"); + } } #endif // iOS diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc index 057a1831b..8e790b00a 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc @@ -18,6 +18,7 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/detection.pb.h" @@ -26,28 +27,61 @@ #include "mediapipe/framework/port/ret_check.h" #include "tensorflow/lite/interpreter.h" -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #include "mediapipe/gpu/gl_calculator_helper.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // ANDROID +#endif // !MEDIAPIPE_DISABLE_GPU -#if defined(__ANDROID__) -using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; -using ::tflite::gpu::gl::GlBuffer; -using ::tflite::gpu::gl::GlProgram; -using ::tflite::gpu::gl::GlShader; -#endif // ANDROID +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS +#import +#import +#import -namespace mediapipe { +#import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" 
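For the Metal output path, the hunk above pads every tensor dimension up to a multiple of four (the delegate's BPHWC4 layout) and records the original B/H/W/C shape so the converter can reshape results back afterwards. A rough standalone sketch of that sizing logic; the struct and function names are made up for illustration:

#include <cstdio>

int RoundUp(int value, int multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

struct Bhwc { int b = 1, h = 1, w = 1, c = 1; };

// dims/size mirror TfLiteTensor::dims; the mapping follows the switch in the
// patch (rank 2 -> [B,C], rank 3 -> [B,W,C], rank 4 -> [B,H,W,C]).
bool ShapeFromDims(const int* dims, int size, Bhwc* shape, int* padded_elements) {
  *padded_elements = 1;
  for (int d = 0; d < size; ++d) *padded_elements *= RoundUp(dims[d], 4);
  shape->b = dims[0];
  switch (size) {
    case 2: shape->c = dims[1]; break;
    case 3: shape->w = dims[1]; shape->c = dims[2]; break;
    case 4: shape->h = dims[1]; shape->w = dims[2]; shape->c = dims[3]; break;
    default: return false;  // unsupported tensor rank
  }
  return true;
}

int main() {
  const int dims[] = {1, 128, 128, 3};
  Bhwc shape;
  int padded = 0;
  ShapeFromDims(dims, 4, &shape, &padded);
  std::printf("padded elements: %d (b=%d h=%d w=%d c=%d)\n",
              padded, shape.b, shape.h, shape.w, shape.c);
}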
+#include "mediapipe/gpu/gpu_buffer.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate.h" +#endif // iOS namespace { constexpr int kNumInputTensorsWithAnchors = 3; constexpr int kNumCoordsPerBox = 4; +} // namespace + +namespace mediapipe { + +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; +using ::tflite::gpu::gl::GlShader; +#endif + +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +typedef ::tflite::gpu::gl::GlBuffer GpuTensor; +typedef ::tflite::gpu::gl::GlProgram GpuProgram; +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS +typedef id GpuTensor; +typedef id GpuProgram; +#endif + +namespace { + +#if !defined(MEDIAPIPE_DISABLE_GPU) +struct GPUData { + GpuProgram decode_program; + GpuProgram score_program; + GpuTensor decoded_boxes_buffer; + GpuTensor raw_boxes_buffer; + GpuTensor raw_anchors_buffer; + GpuTensor scored_boxes_buffer; + GpuTensor raw_scores_buffer; +}; +#endif + void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, std::vector* anchors) { anchors->clear(); @@ -88,7 +122,7 @@ void ConvertAnchorsToRawValues(const std::vector& anchors, // optional to pass in a third tensor for anchors (e.g. for SSD // models) depend on the outputs of the detection model. The size // of anchor tensor must be (num_boxes * 4). -// TENSORS_GPU - vector of GlBuffer. +// TENSORS_GPU - vector of GlBuffer of MTLBuffer. // Output: // DETECTIONS - Result MediaPipe detections. // @@ -126,7 +160,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { std::vector* output_detections); ::mediapipe::Status LoadOptions(CalculatorContext* cc); - ::mediapipe::Status GlSetup(CalculatorContext* cc); + ::mediapipe::Status GpuInit(CalculatorContext* cc); ::mediapipe::Status DecodeBoxes(const float* raw_boxes, const std::vector& anchors, std::vector* boxes); @@ -146,15 +180,12 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { std::vector anchors_; bool side_packet_anchors_{}; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) mediapipe::GlCalculatorHelper gpu_helper_; - std::unique_ptr decode_program_; - std::unique_ptr score_program_; - std::unique_ptr decoded_boxes_buffer_; - std::unique_ptr raw_boxes_buffer_; - std::unique_ptr raw_anchors_buffer_; - std::unique_ptr scored_boxes_buffer_; - std::unique_ptr raw_scores_buffer_; + std::unique_ptr gpu_data_; +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + MPPMetalHelper* gpu_helper_ = nullptr; + std::unique_ptr gpu_data_; #endif bool gpu_input_ = false; @@ -167,15 +198,18 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); + bool use_gpu = false; + if (cc->Inputs().HasTag("TENSORS")) { cc->Inputs().Tag("TENSORS").Set>(); } -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag("TENSORS_GPU")) { - cc->Inputs().Tag("TENSORS_GPU").Set>(); + cc->Inputs().Tag("TENSORS_GPU").Set>(); + use_gpu |= true; } -#endif +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("DETECTIONS")) { cc->Outputs().Tag("DETECTIONS").Set>(); @@ -187,9 +221,13 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } } -#if defined(__ANDROID__) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#elif 
defined(__APPLE__) && !TARGET_OS_OSX // iOS + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif + } return ::mediapipe::OkStatus(); } @@ -200,8 +238,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); if (cc->Inputs().HasTag("TENSORS_GPU")) { gpu_input_ = true; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); #endif } @@ -209,7 +250,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS"); if (gpu_input_) { - MP_RETURN_IF_ERROR(GlSetup(cc)); + MP_RETURN_IF_ERROR(GpuInit(cc)); } return ::mediapipe::OkStatus(); @@ -228,7 +269,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); } else { MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); - } // if gpu_input_ + } // Output if (cc->Outputs().HasTag("DETECTIONS")) { @@ -245,7 +286,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); const auto& input_tensors = cc->Inputs().Tag("TENSORS").Get>(); - if (input_tensors.size() == 2) { + if (input_tensors.size() == 2 || + input_tensors.size() == kNumInputTensorsWithAnchors) { // Postprocessing on CPU for model without postprocessing op. E.g. output // raw score tensor and box tensor. Anchor decoding will be handled below. const TfLiteTensor* raw_box_tensor = &input_tensors[0]; @@ -358,13 +400,84 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( CalculatorContext* cc, std::vector* output_detections) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) const auto& input_tensors = - cc->Inputs().Tag("TENSORS_GPU").Get>(); + cc->Inputs().Tag("TENSORS_GPU").Get>(); + RET_CHECK_GE(input_tensors.size(), 2); + + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc, + &output_detections]() + -> ::mediapipe::Status { + // Copy inputs. + RET_CHECK_CALL(CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer)); + RET_CHECK_CALL(CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer)); + if (!anchors_init_) { + if (side_packet_anchors_) { + CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); + const auto& anchors = + cc->InputSidePackets().Tag("ANCHORS").Get>(); + std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); + ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); + RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.Write( + absl::MakeSpan(raw_anchors))); + } else { + CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); + RET_CHECK_CALL( + CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer)); + } + anchors_init_ = true; + } + + // Run shaders. + // Decode boxes. + RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.BindToIndex(0)); + RET_CHECK_CALL(gpu_data_->raw_boxes_buffer.BindToIndex(1)); + RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.BindToIndex(2)); + const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; + RET_CHECK_CALL(gpu_data_->decode_program.Dispatch(decode_workgroups)); + + // Score boxes. 
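When anchors arrive as the ANCHORS side packet, they are flattened to four floats per box before being written into the raw_anchors buffer, in the order the decode kernel reads them back (y_center, x_center, h, w). A standalone sketch of that flattening; the Anchor struct is a stand-in for the proto used by the calculator:

#include <vector>

struct Anchor { float y_center, x_center, h, w; };

std::vector<float> FlattenAnchors(const std::vector<Anchor>& anchors) {
  std::vector<float> raw;
  raw.reserve(anchors.size() * 4);  // kNumCoordsPerBox == 4
  for (const Anchor& a : anchors) {
    raw.push_back(a.y_center);
    raw.push_back(a.x_center);
    raw.push_back(a.h);
    raw.push_back(a.w);
  }
  return raw;
}

int main() {
  std::vector<Anchor> anchors = {{0.5f, 0.5f, 0.2f, 0.2f}};
  return FlattenAnchors(anchors).size() == 4 ? 0 : 1;
}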
+ RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.BindToIndex(0)); + RET_CHECK_CALL(gpu_data_->raw_scores_buffer.BindToIndex(1)); + const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; + RET_CHECK_CALL(gpu_data_->score_program.Dispatch(score_workgroups)); + + // Copy decoded boxes from GPU to CPU. + std::vector boxes(num_boxes_ * num_coords_); + RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes))); + std::vector score_class_id_pairs(num_boxes_ * 2); + RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.Read( + absl::MakeSpan(score_class_id_pairs))); + + // TODO: b/138851969. Is it possible to output a float vector + // for score and an int vector for class so that we can avoid copying twice? + std::vector detection_scores(num_boxes_); + std::vector detection_classes(num_boxes_); + for (int i = 0; i < num_boxes_; ++i) { + detection_scores[i] = score_class_id_pairs[i * 2]; + detection_classes[i] = static_cast(score_class_id_pairs[i * 2 + 1]); + } + MP_RETURN_IF_ERROR( + ConvertToDetections(boxes.data(), detection_scores.data(), + detection_classes.data(), output_detections)); + + return ::mediapipe::OkStatus(); + })); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + + const auto& input_tensors = + cc->Inputs().Tag("TENSORS_GPU").Get>(); + RET_CHECK_GE(input_tensors.size(), 2); // Copy inputs. - tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get()); - tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get()); + [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_boxes_buffer + from:input_tensors[0] + blocking:true + commandBuffer:[gpu_helper_ commandBuffer]]; + [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_scores_buffer + from:input_tensors[1] + blocking:true + commandBuffer:[gpu_helper_ commandBuffer]]; if (!anchors_init_) { if (side_packet_anchors_) { CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); @@ -372,47 +485,65 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); cc->InputSidePackets().Tag("ANCHORS").Get>(); std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); - raw_anchors_buffer_->Write(absl::MakeSpan(raw_anchors)); + memcpy([gpu_data_->raw_anchors_buffer contents], raw_anchors.data(), + raw_anchors.size() * sizeof(float)); } else { - CHECK_EQ(input_tensors.size(), 3); - tflite::gpu::gl::CopyBuffer(input_tensors[2], *raw_anchors_buffer_.get()); + RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); + [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_anchors_buffer + from:input_tensors[2] + blocking:true + commandBuffer:[gpu_helper_ commandBuffer]]; } anchors_init_ = true; } // Run shaders. - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors]() -> ::mediapipe::Status { - // Decode boxes. - decoded_boxes_buffer_->BindToIndex(0); - raw_boxes_buffer_->BindToIndex(1); - raw_anchors_buffer_->BindToIndex(2); - const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; - decode_program_->Dispatch(decode_workgroups); - - // Score boxes. 
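The readback above pulls num_boxes_ * 2 floats out of the scored-boxes buffer and splits the interleaved (score, class) pairs before detections are built. A small standalone sketch of that unpacking:

#include <vector>

void UnpackScoredBoxes(const std::vector<float>& score_class_pairs, int num_boxes,
                       std::vector<float>* scores, std::vector<int>* classes) {
  scores->resize(num_boxes);
  classes->resize(num_boxes);
  for (int i = 0; i < num_boxes; ++i) {
    (*scores)[i] = score_class_pairs[i * 2];
    (*classes)[i] = static_cast<int>(score_class_pairs[i * 2 + 1]);
  }
}

int main() {
  std::vector<float> pairs = {0.9f, 3.0f, 0.2f, 0.0f};  // (score, class) per box
  std::vector<float> scores;
  std::vector<int> classes;
  UnpackScoredBoxes(pairs, 2, &scores, &classes);
  return classes[0] == 3 ? 0 : 1;
}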
- scored_boxes_buffer_->BindToIndex(0); - raw_scores_buffer_->BindToIndex(1); - const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; - score_program_->Dispatch(score_workgroups); - - return ::mediapipe::OkStatus(); - })); + { + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteDecodeBoxes"; + id decode_command = + [command_buffer computeCommandEncoder]; + [decode_command setComputePipelineState:gpu_data_->decode_program]; + [decode_command setBuffer:gpu_data_->decoded_boxes_buffer + offset:0 + atIndex:0]; + [decode_command setBuffer:gpu_data_->raw_boxes_buffer offset:0 atIndex:1]; + [decode_command setBuffer:gpu_data_->raw_anchors_buffer offset:0 atIndex:2]; + MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1); + MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1); + [decode_command dispatchThreadgroups:decode_threadgroups + threadsPerThreadgroup:decode_threads_per_group]; + [decode_command endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; + } + { + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteScoreBoxes"; + id score_command = + [command_buffer computeCommandEncoder]; + [score_command setComputePipelineState:gpu_data_->score_program]; + [score_command setBuffer:gpu_data_->scored_boxes_buffer offset:0 atIndex:0]; + [score_command setBuffer:gpu_data_->raw_scores_buffer offset:0 atIndex:1]; + MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); + MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); + [score_command dispatchThreadgroups:score_threadgroups + threadsPerThreadgroup:score_threads_per_group]; + [score_command endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; + } // Copy decoded boxes from GPU to CPU. std::vector boxes(num_boxes_ * num_coords_); - auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes)); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + memcpy(boxes.data(), [gpu_data_->decoded_boxes_buffer contents], + num_boxes_ * num_coords_ * sizeof(float)); std::vector score_class_id_pairs(num_boxes_ * 2); - status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs)); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + memcpy(score_class_id_pairs.data(), [gpu_data_->scored_boxes_buffer contents], + num_boxes_ * 2 * sizeof(float)); - // TODO: b/138851969. Is it possible to output a float vector - // for score and an int vector for class so that we can avoid copying twice? + // Output detections. + // TODO Adjust shader to avoid copying shader output twice. 
std::vector detection_scores(num_boxes_); std::vector detection_classes(num_boxes_); for (int i = 0; i < num_boxes_; ++i) { @@ -422,25 +553,20 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); MP_RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(), detection_classes.data(), output_detections)); + #else LOG(ERROR) << "GPU input on non-Android not supported yet."; -#endif // defined(__ANDROID__) +#endif return ::mediapipe::OkStatus(); } ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( CalculatorContext* cc) { -#if defined(__ANDROID__) - gpu_helper_.RunInGlContext([this] { - decode_program_.reset(); - score_program_.reset(); - decoded_boxes_buffer_.reset(); - raw_boxes_buffer_.reset(); - raw_anchors_buffer_.reset(); - scored_boxes_buffer_.reset(); - raw_scores_buffer_.reset(); - }); -#endif // __ANDROID__ +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); }); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + gpu_data_.reset(); +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -530,6 +656,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } } } + return ::mediapipe::OkStatus(); } @@ -586,12 +713,16 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( return detection; } -::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GlSetup( +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit( CalculatorContext* cc) { -#if defined(__ANDROID__) - // A shader to decode detection boxes. - const std::string decode_src = absl::Substitute( - R"( #version 310 es +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() + -> ::mediapipe::Status { + gpu_data_ = absl::make_unique(); + + // A shader to decode detection boxes. + const std::string decode_src = absl::Substitute( + R"( #version 310 es layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; @@ -665,7 +796,7 @@ void main() { if (num_keypoints > int(0)){ for (int k = 0; k < num_keypoints; ++k) { int kp_offset = - int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; + int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; float kp_y, kp_x; if (reverse_output_order == int(0)) { kp_y = raw_boxes.data[kp_offset + int(0)]; @@ -679,55 +810,37 @@ void main() { } } })", - options_.num_coords(), // box xywh - options_.reverse_output_order() ? 1 : 0, - options_.apply_exponential_on_box_size() ? 1 : 0, - options_.box_coord_offset(), options_.num_keypoints(), - options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); + options_.num_coords(), // box xywh + options_.reverse_output_order() ? 1 : 0, + options_.apply_exponential_on_box_size() ? 
1 : 0, + options_.box_coord_offset(), options_.num_keypoints(), + options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); - // Shader program - GlShader decode_shader; - auto status = - GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - decode_program_ = absl::make_unique(); - status = GlProgram::CreateWithShader(decode_shader, decode_program_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - // Outputs - size_t decoded_boxes_length = num_boxes_ * num_coords_; - decoded_boxes_buffer_ = absl::make_unique(); - status = CreateReadWriteShaderStorageBuffer( - decoded_boxes_length, decoded_boxes_buffer_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - // Inputs - size_t raw_boxes_length = num_boxes_ * num_coords_; - raw_boxes_buffer_ = absl::make_unique(); - status = CreateReadWriteShaderStorageBuffer(raw_boxes_length, - raw_boxes_buffer_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox; - raw_anchors_buffer_ = absl::make_unique(); - status = CreateReadWriteShaderStorageBuffer(raw_anchors_length, - raw_anchors_buffer_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - // Parameters - glUseProgram(decode_program_->id()); - glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(), - options_.h_scale()); + // Shader program + GlShader decode_shader; + RET_CHECK_CALL( + GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader)); + RET_CHECK_CALL(GpuProgram::CreateWithShader(decode_shader, + &gpu_data_->decode_program)); + // Outputs + size_t decoded_boxes_length = num_boxes_ * num_coords_; + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + decoded_boxes_length, &gpu_data_->decoded_boxes_buffer)); + // Inputs + size_t raw_boxes_length = num_boxes_ * num_coords_; + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + raw_boxes_length, &gpu_data_->raw_boxes_buffer)); + size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox; + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + raw_anchors_length, &gpu_data_->raw_anchors_buffer)); + // Parameters + glUseProgram(gpu_data_->decode_program.id()); + glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(), + options_.h_scale()); - // A shader to score detection boxes. - const std::string score_src = absl::Substitute( - R"( #version 310 es + // A shader to score detection boxes. + const std::string score_src = absl::Substitute( + R"( #version 310 es layout(local_size_x = 1, local_size_y = $0, local_size_z = 1) in; @@ -781,6 +894,228 @@ void main() { scored_boxes.data[g_idx * uint(2) + uint(0)] = max_score; scored_boxes.data[g_idx * uint(2) + uint(1)] = max_class; } +})", + num_classes_, options_.sigmoid_score() ? 1 : 0, + options_.has_score_clipping_thresh() ? 1 : 0, + options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() + : 0, + !ignore_classes_.empty() ? 1 : 0); + + // # filter classes supported is hardware dependent. + int max_wg_size; // typically <= 1024 + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, + &max_wg_size); // y-dim + CHECK_LT(num_classes_, max_wg_size) + << "# classes must be < " << max_wg_size; + // TODO support better filtering. 
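The score shader parameterized above reduces each box's raw class scores to a single (score, class) pair: an optional sigmoid with optional clipping, then a max over classes, optionally skipping class 0. A CPU reference of that reduction, with illustrative parameter names:

#include <cmath>
#include <cstdio>
#include <vector>

float OptionalSigmoid(float x, bool apply_sigmoid, bool apply_clipping, float clip_thresh) {
  if (!apply_sigmoid) return x;
  if (apply_clipping) x = std::fmax(-clip_thresh, std::fmin(clip_thresh, x));
  return 1.f / (1.f + std::exp(-x));
}

void BestClassForBox(const std::vector<float>& raw_scores, bool apply_sigmoid,
                     bool apply_clipping, float clip_thresh, bool ignore_class_0,
                     float* max_score, int* max_class) {
  *max_score = -1e30f;
  *max_class = -1;
  for (int c = ignore_class_0 ? 1 : 0; c < static_cast<int>(raw_scores.size()); ++c) {
    const float score = OptionalSigmoid(raw_scores[c], apply_sigmoid,
                                        apply_clipping, clip_thresh);
    if (score > *max_score) {
      *max_score = score;
      *max_class = c;
    }
  }
}

int main() {
  std::vector<float> raw_scores = {5.f, -2.f, 1.5f};  // class 0 treated as background
  float score;
  int cls;
  BestClassForBox(raw_scores, /*apply_sigmoid=*/true, /*apply_clipping=*/true,
                  /*clip_thresh=*/100.f, /*ignore_class_0=*/true, &score, &cls);
  std::printf("class %d, score %.3f\n", cls, score);  // expects class 2
}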
+ CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + + // Shader program + GlShader score_shader; + RET_CHECK_CALL( + GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader)); + RET_CHECK_CALL( + GpuProgram::CreateWithShader(score_shader, &gpu_data_->score_program)); + // Outputs + size_t scored_boxes_length = num_boxes_ * 2; // score, class + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + scored_boxes_length, &gpu_data_->scored_boxes_buffer)); + // Inputs + size_t raw_scores_length = num_boxes_ * num_classes_; + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + raw_scores_length, &gpu_data_->raw_scores_buffer)); + + return ::mediapipe::OkStatus(); + })); + +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + // TODO consolidate Metal and OpenGL shaders via vulkan. + + gpu_data_ = absl::make_unique(); + id device = gpu_helper_.mtlDevice; + + // A shader to decode detection boxes. + std::string decode_src = absl::Substitute( + R"( +#include + +using namespace metal; + +kernel void decodeKernel( + device float* boxes [[ buffer(0) ]], + device float* raw_boxes [[ buffer(1) ]], + device float* raw_anchors [[ buffer(2) ]], + uint2 gid [[ thread_position_in_grid ]]) { + + uint num_coords = uint($0); + int reverse_output_order = int($1); + int apply_exponential = int($2); + int box_coord_offset = int($3); + int num_keypoints = int($4); + int keypt_coord_offset = int($5); + int num_values_per_keypt = int($6); +)", + options_.num_coords(), // box xywh + options_.reverse_output_order() ? 1 : 0, + options_.apply_exponential_on_box_size() ? 1 : 0, + options_.box_coord_offset(), options_.num_keypoints(), + options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); + decode_src += absl::Substitute( + R"( + float4 scale = float4(($0),($1),($2),($3)); +)", + options_.x_scale(), options_.y_scale(), options_.w_scale(), + options_.h_scale()); + decode_src += R"( + uint g_idx = gid.x; + uint box_offset = g_idx * num_coords + uint(box_coord_offset); + uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox + + float y_center, x_center, h, w; + + if (reverse_output_order == int(0)) { + y_center = raw_boxes[box_offset + uint(0)]; + x_center = raw_boxes[box_offset + uint(1)]; + h = raw_boxes[box_offset + uint(2)]; + w = raw_boxes[box_offset + uint(3)]; + } else { + x_center = raw_boxes[box_offset + uint(0)]; + y_center = raw_boxes[box_offset + uint(1)]; + w = raw_boxes[box_offset + uint(2)]; + h = raw_boxes[box_offset + uint(3)]; + } + + float anchor_yc = raw_anchors[anchor_offset + uint(0)]; + float anchor_xc = raw_anchors[anchor_offset + uint(1)]; + float anchor_h = raw_anchors[anchor_offset + uint(2)]; + float anchor_w = raw_anchors[anchor_offset + uint(3)]; + + x_center = x_center / scale.x * anchor_w + anchor_xc; + y_center = y_center / scale.y * anchor_h + anchor_yc; + + if (apply_exponential == int(1)) { + h = exp(h / scale.w) * anchor_h; + w = exp(w / scale.z) * anchor_w; + } else { + h = (h / scale.w) * anchor_h; + w = (w / scale.z) * anchor_w; + } + + float ymin = y_center - h / 2.0; + float xmin = x_center - w / 2.0; + float ymax = y_center + h / 2.0; + float xmax = x_center + w / 2.0; + + boxes[box_offset + uint(0)] = ymin; + boxes[box_offset + uint(1)] = xmin; + boxes[box_offset + uint(2)] = ymax; + boxes[box_offset + uint(3)] = xmax; + + if (num_keypoints > int(0)){ + for (int k = 0; k < num_keypoints; ++k) { + int kp_offset = + int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; + float kp_y, kp_x; + if 
(reverse_output_order == int(0)) { + kp_y = raw_boxes[kp_offset + int(0)]; + kp_x = raw_boxes[kp_offset + int(1)]; + } else { + kp_x = raw_boxes[kp_offset + int(0)]; + kp_y = raw_boxes[kp_offset + int(1)]; + } + boxes[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc; + boxes[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc; + } + } +})"; + + { + // Shader program + NSString* library_source = + [NSString stringWithUTF8String:decode_src.c_str()]; + NSError* error = nil; + id library = [device newLibraryWithSource:library_source + options:nullptr + error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"decodeKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + gpu_data_->decode_program = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(gpu_data_->decode_program != nil) + << "Couldn't create pipeline state " + << [[error localizedDescription] UTF8String]; + // Outputs + size_t decoded_boxes_length = num_boxes_ * num_coords_ * sizeof(float); + gpu_data_->decoded_boxes_buffer = + [device newBufferWithLength:decoded_boxes_length + options:MTLResourceStorageModeShared]; + // Inputs + size_t raw_boxes_length = num_boxes_ * num_coords_ * sizeof(float); + gpu_data_->raw_boxes_buffer = + [device newBufferWithLength:raw_boxes_length + options:MTLResourceStorageModeShared]; + size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox * sizeof(float); + gpu_data_->raw_anchors_buffer = + [device newBufferWithLength:raw_anchors_length + options:MTLResourceStorageModeShared]; + } + + // A shader to score detection boxes. + const std::string score_src = absl::Substitute( + R"( +#include + +using namespace metal; + +float optional_sigmoid(float x) { + int apply_sigmoid = int($1); + int apply_clipping_thresh = int($2); + float clipping_thresh = float($3); + if (apply_sigmoid == int(0)) return x; + if (apply_clipping_thresh == int(1)) { + x = clamp(x, -clipping_thresh, clipping_thresh); + } + x = 1.0 / (1.0 + exp(-x)); + return x; +} + +kernel void scoreKernel( + device float* scored_boxes [[ buffer(0) ]], + device float* raw_scores [[ buffer(1) ]], + uint2 tid [[ thread_position_in_threadgroup ]], + uint2 gid [[ thread_position_in_grid ]]) { + + uint num_classes = uint($0); + int apply_sigmoid = int($1); + int apply_clipping_thresh = int($2); + float clipping_thresh = float($3); + int ignore_class_0 = int($4); + + uint g_idx = gid.x; // box idx + uint s_idx = tid.y; // score/class idx + + // load all scores into shared memory + threadgroup float local_scores[$0]; + float score = raw_scores[g_idx * num_classes + s_idx]; + local_scores[s_idx] = optional_sigmoid(score); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // find max score in shared memory + if (s_idx == uint(0)) { + float max_score = -FLT_MAX; + float max_class = -1.0; + for (int i=ignore_class_0; i max_score) { + max_score = local_scores[i]; + max_class = float(i); + } + } + scored_boxes[g_idx * uint(2) + uint(0)] = max_score; + scored_boxes[g_idx * uint(2) + uint(1)] = max_class; + } })", num_classes_, options_.sigmoid_score() ? 1 : 0, options_.has_score_clipping_thresh() ? 1 : 0, @@ -788,42 +1123,44 @@ void main() { : 0, ignore_classes_.size() ? 1 : 0); - // # filter classes supported is hardware dependent. 
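Both the GLSL and the Metal decode kernels above apply the same anchor-based box decoding: center offsets are scaled by the anchor size, box sizes optionally pass through exp(), and the result is emitted as (ymin, xmin, ymax, xmax). A CPU reference of that math, keypoints omitted and names illustrative:

#include <cmath>
#include <cstdio>

struct DecodedBox { float ymin, xmin, ymax, xmax; };

DecodedBox DecodeBox(const float raw_box[4], const float anchor[4] /* yc, xc, h, w */,
                     float x_scale, float y_scale, float w_scale, float h_scale,
                     bool reverse_output_order, bool apply_exponential) {
  float y_center, x_center, h, w;
  if (!reverse_output_order) {
    y_center = raw_box[0]; x_center = raw_box[1]; h = raw_box[2]; w = raw_box[3];
  } else {
    x_center = raw_box[0]; y_center = raw_box[1]; w = raw_box[2]; h = raw_box[3];
  }
  x_center = x_center / x_scale * anchor[3] + anchor[1];
  y_center = y_center / y_scale * anchor[2] + anchor[0];
  if (apply_exponential) {
    h = std::exp(h / h_scale) * anchor[2];
    w = std::exp(w / w_scale) * anchor[3];
  } else {
    h = h / h_scale * anchor[2];
    w = w / w_scale * anchor[3];
  }
  return {y_center - h / 2.f, x_center - w / 2.f,
          y_center + h / 2.f, x_center + w / 2.f};
}

int main() {
  const float raw_box[4] = {0.f, 0.f, 0.f, 0.f};   // zero offset and size
  const float anchor[4] = {0.5f, 0.5f, 0.2f, 0.2f};
  // With exponential sizing, a zero-valued raw box reproduces its anchor.
  DecodedBox b = DecodeBox(raw_box, anchor, 10.f, 10.f, 5.f, 5.f,
                           /*reverse_output_order=*/false,
                           /*apply_exponential=*/true);
  std::printf("ymin=%.2f xmin=%.2f ymax=%.2f xmax=%.2f\n",
              b.ymin, b.xmin, b.ymax, b.xmax);
}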
- int max_wg_size; // typically <= 1024 - glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &max_wg_size); // y-dim - CHECK_LT(num_classes_, max_wg_size) << "# classes must be < " << max_wg_size; // TODO support better filtering. CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; - // Shader program - GlShader score_shader; - status = GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - score_program_ = absl::make_unique(); - status = GlProgram::CreateWithShader(score_shader, score_program_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - // Outputs - size_t scored_boxes_length = num_boxes_ * 2; // score, class - scored_boxes_buffer_ = absl::make_unique(); - status = CreateReadWriteShaderStorageBuffer( - scored_boxes_length, scored_boxes_buffer_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - // Inputs - size_t raw_scores_length = num_boxes_ * num_classes_; - raw_scores_buffer_ = absl::make_unique(); - status = CreateReadWriteShaderStorageBuffer(raw_scores_length, - raw_scores_buffer_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); + { + // Shader program + NSString* library_source = + [NSString stringWithUTF8String:score_src.c_str()]; + NSError* error = nil; + id library = [device newLibraryWithSource:library_source + options:nullptr + error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"scoreKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + gpu_data_->score_program = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(gpu_data_->score_program != nil) + << "Couldn't create pipeline state " + << [[error localizedDescription] UTF8String]; + // Outputs + size_t scored_boxes_length = num_boxes_ * 2 * sizeof(float); // score,class + gpu_data_->scored_boxes_buffer = + [device newBufferWithLength:scored_boxes_length + options:MTLResourceStorageModeShared]; + // Inputs + size_t raw_scores_length = num_boxes_ * num_classes_ * sizeof(float); + gpu_data_->raw_scores_buffer = + [device newBufferWithLength:raw_scores_length + options:MTLResourceStorageModeShared]; + // # filter classes supported is hardware dependent. 
+ int max_wg_size = gpu_data_->score_program.maxTotalThreadsPerThreadgroup; + CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; } -#endif // defined(__ANDROID__) +#endif // __ANDROID__ or iOS + return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc index 3b938a1cd..1d646e4a3 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc @@ -96,7 +96,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); options_.has_input_image_width()) << "Must provide input with/height for getting normalized landmarks."; } - if (cc->Outputs().HasTag("LANDMARKS") && options_.flip_vertically()) { + if (cc->Outputs().HasTag("LANDMARKS") && + (options_.flip_vertically() || options_.flip_horizontally())) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) << "Must provide input with/height for using flip_vertically option " @@ -133,7 +134,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); for (int ld = 0; ld < num_landmarks_; ++ld) { const int offset = ld * num_dimensions; Landmark landmark; - landmark.set_x(raw_landmarks[offset]); + + if (options_.flip_horizontally()) { + landmark.set_x(options_.input_image_width() - raw_landmarks[offset]); + } else { + landmark.set_x(raw_landmarks[offset]); + } if (num_dimensions > 1) { if (options_.flip_vertically()) { landmark.set_y(options_.input_image_height() - diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto index 5f37e6238..3b6716c9c 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto @@ -40,6 +40,12 @@ message TfLiteTensorsToLandmarksCalculatorOptions { // representation has a bottom-left origin (e.g., in OpenGL). optional bool flip_vertically = 4 [default = false]; + // Whether the detection coordinates from the input tensors should be flipped + // horizontally (along the x-direction). This is useful, for example, when the + // input image is horizontally flipped in ImageTransformationCalculator + // beforehand. + optional bool flip_horizontally = 6 [default = false]; + // A value that z values should be divided by. 
optional float normalize_z = 5 [default = 1.0]; } diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index 20aca8e30..16805a066 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -17,6 +17,7 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" @@ -27,7 +28,7 @@ #include "mediapipe/util/resource_util.h" #include "tensorflow/lite/interpreter.h" -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" @@ -36,7 +37,7 @@ #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. @@ -52,12 +53,14 @@ float Clamp(float val, float min, float max) { namespace mediapipe { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) +using ::tflite::gpu::gl::CopyBuffer; +using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlShader; -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU // Converts TFLite tensors from a tflite segmentation model to an image mask. // @@ -126,13 +129,13 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase { int tensor_channels_ = 0; bool use_gpu_ = false; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) mediapipe::GlCalculatorHelper gpu_helper_; std::unique_ptr mask_program_with_prev_; std::unique_ptr mask_program_no_prev_; std::unique_ptr tensor_buffer_; GLuint upsample_program_; -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); @@ -142,6 +145,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); + bool use_gpu = false; + // Inputs CPU. if (cc->Inputs().HasTag("TENSORS")) { cc->Inputs().Tag("TENSORS").Set>(); @@ -154,32 +159,37 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); } // Inputs GPU. -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) if (cc->Inputs().HasTag("TENSORS_GPU")) { cc->Inputs().Tag("TENSORS_GPU").Set>(); + use_gpu |= true; } if (cc->Inputs().HasTag("PREV_MASK_GPU")) { cc->Inputs().Tag("PREV_MASK_GPU").Set(); + use_gpu |= true; } if (cc->Inputs().HasTag("REFERENCE_IMAGE_GPU")) { cc->Inputs().Tag("REFERENCE_IMAGE_GPU").Set(); + use_gpu |= true; } -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU // Outputs. 
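The flip_horizontally option added to TfLiteTensorsToLandmarksCalculator above mirrors landmark x coordinates across the input image width, complementing the existing flip_vertically behavior on y. A minimal standalone sketch of that transform; the Landmark struct is a stand-in for the proto:

#include <cstdio>

struct Landmark { float x = 0.f, y = 0.f; };

// Mirrors the per-landmark logic: x is reflected across the image width when
// flip_horizontally is set, y across the image height when flip_vertically is.
Landmark ApplyFlips(float raw_x, float raw_y, int image_width, int image_height,
                    bool flip_horizontally, bool flip_vertically) {
  Landmark lm;
  lm.x = flip_horizontally ? image_width - raw_x : raw_x;
  lm.y = flip_vertically ? image_height - raw_y : raw_y;
  return lm;
}

int main() {
  Landmark lm = ApplyFlips(30.f, 40.f, 256, 256, /*flip_horizontally=*/true,
                           /*flip_vertically=*/false);
  std::printf("x=%.1f y=%.1f\n", lm.x, lm.y);  // x becomes 226, y unchanged
}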
if (cc->Outputs().HasTag("MASK")) { cc->Outputs().Tag("MASK").Set(); } -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) if (cc->Outputs().HasTag("MASK_GPU")) { cc->Outputs().Tag("MASK_GPU").Set(); + use_gpu |= true; } -#endif // __ANDROID__ - -#if defined(__ANDROID__) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -189,24 +199,23 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); if (cc->Inputs().HasTag("TENSORS_GPU")) { use_gpu_ = true; -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadOptions(cc)); if (use_gpu_) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { MP_RETURN_IF_ERROR(InitGpu(cc)); return ::mediapipe::OkStatus(); })); #else - RET_CHECK_FAIL() - << "GPU processing on non-Android devices is not supported yet."; -#endif // __ANDROID__ + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } return ::mediapipe::OkStatus(); @@ -215,13 +224,13 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process( CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { MP_RETURN_IF_ERROR(ProcessGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(ProcessCpu(cc)); } @@ -231,7 +240,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( CalculatorContext* cc) { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) gpu_helper_.RunInGlContext([this] { if (upsample_program_) glDeleteProgram(upsample_program_); upsample_program_ = 0; @@ -239,7 +248,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); mask_program_no_prev_.reset(); tensor_buffer_.reset(); }); -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -358,7 +367,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) // Get input streams. const auto& input_tensors = cc->Inputs().Tag("TENSORS_GPU").Get>(); @@ -379,9 +388,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // Create initial working mask texture. ::tflite::gpu::gl::GlTexture small_mask_texture; - ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture( + RET_CHECK_CALL(CreateReadWriteRgbaImageTexture( tflite::gpu::DataType::UINT8, // GL_RGBA8 - {tensor_width_, tensor_height_}, &small_mask_texture); + {tensor_width_, tensor_height_}, &small_mask_texture)); // Get input previous mask. 
auto input_mask_texture = has_prev_mask @@ -389,7 +398,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); : mediapipe::GlTexture(); // Copy input tensor. - tflite::gpu::gl::CopyBuffer(input_tensors[0], *tensor_buffer_); + RET_CHECK_CALL(CopyBuffer(input_tensors[0], *tensor_buffer_)); // Run shader, process mask tensor. // Run softmax over tensor output and blend with previous mask. @@ -397,18 +406,18 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); const int output_index = 0; glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_RGBA8); - tensor_buffer_->BindToIndex(2); + RET_CHECK_CALL(tensor_buffer_->BindToIndex(2)); const tflite::gpu::uint3 workgroups = { NumGroups(tensor_width_, kWorkgroupSize), NumGroups(tensor_height_, kWorkgroupSize), 1}; if (!has_prev_mask) { - mask_program_no_prev_->Dispatch(workgroups); + RET_CHECK_CALL(mask_program_no_prev_->Dispatch(workgroups)); } else { glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, input_mask_texture.name()); - mask_program_with_prev_->Dispatch(workgroups); + RET_CHECK_CALL(mask_program_with_prev_->Dispatch(workgroups)); glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, 0); } @@ -438,13 +447,13 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // Cleanup input_mask_texture.Release(); output_texture.Release(); -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } void TfLiteTensorsToSegmentationCalculator::GlRender() { -#if defined(__ANDROID__) +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -492,7 +501,7 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ +#endif // !MEDIAPIPE_DISABLE_GPU } ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( @@ -516,14 +525,15 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() { ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu( CalculatorContext* cc) { -#if defined(__ANDROID__) - - // A shader to process a segmentation tensor into an output mask, - // and use an optional previous mask as input. - // Currently uses 4 channels for output, - // and sets both R and A channels as mask value. - const std::string shader_src_template = - R"( #version 310 es +#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() + -> ::mediapipe::Status { + // A shader to process a segmentation tensor into an output mask, + // and use an optional previous mask as input. + // Currently uses 4 channels for output, + // and sets both R and A channels as mask value. + const std::string shader_src_template = + R"( #version 310 es layout(local_size_x = $0, local_size_y = $0, local_size_z = 1) in; @@ -589,76 +599,60 @@ void main() { imageStore(output_texture, output_coordinate, out_value); })"; - const std::string shader_src_no_previous = absl::Substitute( - shader_src_template, kWorkgroupSize, options_.output_layer_index(), - options_.combine_with_previous_ratio(), "", - options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); - const std::string shader_src_with_previous = absl::Substitute( - shader_src_template, kWorkgroupSize, options_.output_layer_index(), - options_.combine_with_previous_ratio(), "#define READ_PREVIOUS", - options_.flip_vertically() ? 
"out_height - gid.y - 1" : "gid.y"); + const std::string shader_src_no_previous = absl::Substitute( + shader_src_template, kWorkgroupSize, options_.output_layer_index(), + options_.combine_with_previous_ratio(), "", + options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); + const std::string shader_src_with_previous = absl::Substitute( + shader_src_template, kWorkgroupSize, options_.output_layer_index(), + options_.combine_with_previous_ratio(), "#define READ_PREVIOUS", + options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); - auto status = ::tflite::gpu::OkStatus(); + // Shader programs. + GlShader shader_without_previous; + RET_CHECK_CALL(GlShader::CompileShader( + GL_COMPUTE_SHADER, shader_src_no_previous, &shader_without_previous)); + mask_program_no_prev_ = absl::make_unique(); + RET_CHECK_CALL(GlProgram::CreateWithShader(shader_without_previous, + mask_program_no_prev_.get())); + GlShader shader_with_previous; + RET_CHECK_CALL(GlShader::CompileShader( + GL_COMPUTE_SHADER, shader_src_with_previous, &shader_with_previous)); + mask_program_with_prev_ = absl::make_unique(); + RET_CHECK_CALL(GlProgram::CreateWithShader(shader_with_previous, + mask_program_with_prev_.get())); - // Shader programs. - GlShader shader_without_previous; - status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src_no_previous, - &shader_without_previous); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - mask_program_no_prev_ = absl::make_unique(); - status = GlProgram::CreateWithShader(shader_without_previous, - mask_program_no_prev_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - GlShader shader_with_previous; - status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src_with_previous, - &shader_with_previous); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - mask_program_with_prev_ = absl::make_unique(); - status = GlProgram::CreateWithShader(shader_with_previous, - mask_program_with_prev_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + // Buffer storage for input tensor. + size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_; + tensor_buffer_ = absl::make_unique(); + RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer( + tensor_length, tensor_buffer_.get())); - // Buffer storage for input tensor. - size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_; - tensor_buffer_ = absl::make_unique(); - status = CreateReadWriteShaderStorageBuffer(tensor_length, - tensor_buffer_.get()); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } + // Parameters. + glUseProgram(mask_program_with_prev_->id()); + glUniform2i(glGetUniformLocation(mask_program_with_prev_->id(), "out_size"), + tensor_width_, tensor_height_); + glUniform1i( + glGetUniformLocation(mask_program_with_prev_->id(), "input_texture"), + 1); + glUseProgram(mask_program_no_prev_->id()); + glUniform2i(glGetUniformLocation(mask_program_no_prev_->id(), "out_size"), + tensor_width_, tensor_height_); + glUniform1i( + glGetUniformLocation(mask_program_no_prev_->id(), "input_texture"), 1); - // Parameters. 
- glUseProgram(mask_program_with_prev_->id()); - glUniform2i(glGetUniformLocation(mask_program_with_prev_->id(), "out_size"), - tensor_width_, tensor_height_); - glUniform1i( - glGetUniformLocation(mask_program_with_prev_->id(), "input_texture"), 1); - glUseProgram(mask_program_no_prev_->id()); - glUniform2i(glGetUniformLocation(mask_program_no_prev_->id(), "out_size"), - tensor_width_, tensor_height_); - glUniform1i( - glGetUniformLocation(mask_program_no_prev_->id(), "input_texture"), 1); + // Vertex shader attributes. + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; - // Vertex shader attributes. - const GLint attr_location[NUM_ATTRIBUTES] = { - ATTRIB_VERTEX, - ATTRIB_TEXTURE_POSITION, - }; - const GLchar* attr_name[NUM_ATTRIBUTES] = { - "position", - "texture_coordinate", - }; - - // Simple pass-through shader, used for hardware upsampling. - std::string upsample_shader_base = R"( + // Simple pass-through shader, used for hardware upsampling. + std::string upsample_shader_base = R"( #if __VERSION__ < 130 #define in varying #endif // __VERSION__ < 130 @@ -683,16 +677,19 @@ void main() { } )"; - // Program - mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, - upsample_shader_base.c_str(), NUM_ATTRIBUTES, - &attr_name[0], attr_location, &upsample_program_); - RET_CHECK(upsample_program_) << "Problem initializing the program."; + // Program + mediapipe::GlhCreateProgram( + mediapipe::kBasicVertexShader, upsample_shader_base.c_str(), + NUM_ATTRIBUTES, &attr_name[0], attr_location, &upsample_program_); + RET_CHECK(upsample_program_) << "Problem initializing the program."; - // Parameters - glUseProgram(upsample_program_); - glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1); -#endif // __ANDROID__ + // Parameters + glUseProgram(upsample_program_); + glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1); + + return ::mediapipe::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tflite/util.h b/mediapipe/calculators/tflite/util.h new file mode 100644 index 000000000..53e0927af --- /dev/null +++ b/mediapipe/calculators/tflite/util.h @@ -0,0 +1,25 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_ +#define MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_ + +#define RET_CHECK_CALL(call) \ + do { \ + const auto status = (call); \ + if (ABSL_PREDICT_FALSE(!status.ok())) \ + return ::mediapipe::InternalError(status.error_message()); \ + } while (0); + +#endif // MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_ diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 3dde96aee..7bd06fe97 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -235,19 +235,13 @@ cc_library( "//mediapipe/framework/port:vector", "//mediapipe/util:annotation_renderer", ] + select({ - "//mediapipe:android": [ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:shader_util", ], - "//mediapipe:ios": [ - "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/gpu:shader_util", - ], - "//conditions:default": [], }), alwayslink = 1, ) @@ -694,3 +688,65 @@ cc_test( "//mediapipe/framework/tool:validate_type", ], ) + +proto_library( + name = "top_k_scores_calculator_proto", + srcs = ["top_k_scores_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "top_k_scores_calculator_cc_proto", + srcs = ["top_k_scores_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//visibility:public"], + deps = [":top_k_scores_calculator_proto"], +) + +cc_library( + name = "top_k_scores_calculator", + srcs = ["top_k_scores_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":top_k_scores_calculator_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/util:resource_util", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/util/android/file/base", + ], + "//mediapipe:apple": [ + "//mediapipe/util/android/file/base", + ], + "//mediapipe:macos": [ + "//mediapipe/framework/port:file_helpers", + ], + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + ], + }), + alwayslink = 1, +) + +cc_test( + name = "top_k_scores_calculator_test", + srcs = ["top_k_scores_calculator_test.cc"], + deps = [ + ":top_k_scores_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index ee21302ef..5f5c53582 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -27,12 +27,12 @@ #include "mediapipe/util/annotation_renderer.h" #include "mediapipe/util/color.pb.h" -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // 
__ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -146,13 +146,13 @@ class AnnotationOverlayCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU. int width_ = 0; int height_ = 0; -#endif // __ANDROID__ or iOS +#endif // MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(AnnotationOverlayCalculator); @@ -160,6 +160,8 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); + bool use_gpu = false; + if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { return ::mediapipe::InternalError("Cannot have multiple input images."); @@ -173,12 +175,13 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); int num_render_streams = cc->Inputs().NumEntries(); // Input image to render onto copy of. -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); num_render_streams = cc->Inputs().NumEntries() - 1; + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); num_render_streams = cc->Inputs().NumEntries() - 1; @@ -190,18 +193,21 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } // Rendered image. -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); + use_gpu |= true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ or iOS + if (use_gpu) { +#if !defined(MEDIAPIPE_DISABLE_GPU) + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } return ::mediapipe::OkStatus(); } @@ -212,11 +218,11 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif // __ANDROID__ or iOS + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif // !MEDIAPIPE_DISABLE_GPU } if (cc->Inputs().HasTag(kInputFrameTagGpu) || @@ -246,9 +252,9 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } return ::mediapipe::OkStatus(); @@ -260,7 +266,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); std::unique_ptr image_mat; ImageFormat::Format target_format; if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (!gpu_initialized_) { MP_RETURN_IF_ERROR( 
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { @@ -269,7 +275,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); })); gpu_initialized_ = true; } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(CreateRenderTargetGpu(cc, image_mat)); } else { MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); @@ -288,7 +294,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } if (use_gpu_) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) // Overlay rendered image in OpenGL, onto a copy of input. uchar* image_mat_ptr = image_mat->data; MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( @@ -296,7 +302,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); MP_RETURN_IF_ERROR(RenderToGpu(cc, image_mat_ptr)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU } else { // Copy the rendered image to output. uchar* image_mat_ptr = image_mat->data; @@ -307,14 +313,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } ::mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_); image_mat_tex_ = 0; }); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -325,7 +331,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); auto output_frame = absl::make_unique( target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight()); -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kGlDefaultAlignmentBoundary); @@ -333,7 +339,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kDefaultAlignmentBoundary); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU cc->Outputs() .Tag(kOutputFrameTag) @@ -344,7 +350,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); ::mediapipe::Status AnnotationOverlayCalculator::RenderToGpu( CalculatorContext* cc, uchar* overlay_image) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) // Source and destination textures. 
const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); @@ -390,7 +396,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } @@ -451,15 +457,16 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); ::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( CalculatorContext* cc, std::unique_ptr& image_mat) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) if (image_frame_available_) { const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); const mediapipe::ImageFormat::Format format = mediapipe::ImageFormatForGpuBufferFormat(input_frame.format()); - if (format != mediapipe::ImageFormat::SRGBA) - RET_CHECK_FAIL() << "Unsupported GPU input format."; + if (format != mediapipe::ImageFormat::SRGBA && + format != mediapipe::ImageFormat::SRGB) + RET_CHECK_FAIL() << "Unsupported GPU input format: " << format; image_mat = absl::make_unique( height_, width_, CV_8UC3, @@ -471,14 +478,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), options_.canvas_color().b())); } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } ::mediapipe::Status AnnotationOverlayCalculator::GlRender( CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -526,14 +533,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } ::mediapipe::Status AnnotationOverlayCalculator::GlSetup( CalculatorContext* cc) { -#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#if !defined(MEDIAPIPE_DISABLE_GPU) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -609,7 +616,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); glBindTexture(GL_TEXTURE_2D, 0); } -#endif // __ANDROID__ or iOS +#endif // !MEDIAPIPE_DISABLE_GPU return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 12cce1fa2..24db78a88 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -23,10 +23,25 @@ namespace mediapipe { namespace { -constexpr char kNormalizedRectTag[] = "NORM_RECT"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kRectTag[] = "RECT"; +constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kRectsTag[] = "RECTS"; constexpr char kRenderDataTag[] = "RENDER_DATA"; +RenderAnnotation::Rectangle* NewRect( + const RectToRenderDataCalculatorOptions& options, RenderData* render_data) { + auto* annotation = render_data->add_render_annotations(); + annotation->mutable_color()->set_r(options.color().r()); + annotation->mutable_color()->set_g(options.color().g()); + annotation->mutable_color()->set_b(options.color().b()); + annotation->set_thickness(options.thickness()); + + return options.filled() + ? 
annotation->mutable_filled_rectangle()->mutable_rectangle() + : annotation->mutable_rectangle(); +} + void SetRect(bool normalized, double xmin, double ymin, double width, double height, double rotation, RenderAnnotation::Rectangle* rect) { @@ -51,6 +66,8 @@ void SetRect(bool normalized, double xmin, double ymin, double width, // One of the following: // NORM_RECT: A NormalizedRect // RECT: A Rect +// NORM_RECTS: An std::vector +// RECTS: An std::vector // // Output: // RENDER_DATA: A RenderData @@ -83,16 +100,27 @@ REGISTER_CALCULATOR(RectToRenderDataCalculator); ::mediapipe::Status RectToRenderDataCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag(kNormalizedRectTag) ^ - cc->Inputs().HasTag(kRectTag)); + RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) + + (cc->Inputs().HasTag(kRectTag) ? 1 : 0) + + (cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) + + (cc->Inputs().HasTag(kRectsTag) ? 1 : 0), + 1) + << "Exactly one of NORM_RECT, RECT, NORM_RECTS or RECTS input stream " + "should be provided."; RET_CHECK(cc->Outputs().HasTag(kRenderDataTag)); - if (cc->Inputs().HasTag(kNormalizedRectTag)) { - cc->Inputs().Tag(kNormalizedRectTag).Set(); + if (cc->Inputs().HasTag(kNormRectTag)) { + cc->Inputs().Tag(kNormRectTag).Set(); } if (cc->Inputs().HasTag(kRectTag)) { cc->Inputs().Tag(kRectTag).Set(); } + if (cc->Inputs().HasTag(kNormRectsTag)) { + cc->Inputs().Tag(kNormRectsTag).Set>(); + } + if (cc->Inputs().HasTag(kRectsTag)) { + cc->Inputs().Tag(kRectsTag).Set>(); + } cc->Outputs().Tag(kRenderDataTag).Set(); return ::mediapipe::OkStatus(); @@ -108,31 +136,43 @@ REGISTER_CALCULATOR(RectToRenderDataCalculator); ::mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { auto render_data = absl::make_unique(); - auto* annotation = render_data->add_render_annotations(); - annotation->mutable_color()->set_r(options_.color().r()); - annotation->mutable_color()->set_g(options_.color().g()); - annotation->mutable_color()->set_b(options_.color().b()); - annotation->set_thickness(options_.thickness()); - auto* rectangle = - options_.filled() - ? 
annotation->mutable_filled_rectangle()->mutable_rectangle() - : annotation->mutable_rectangle(); - - if (cc->Inputs().HasTag(kNormalizedRectTag) && - !cc->Inputs().Tag(kNormalizedRectTag).IsEmpty()) { - const auto& rect = - cc->Inputs().Tag(kNormalizedRectTag).Get(); + if (cc->Inputs().HasTag(kNormRectTag) && + !cc->Inputs().Tag(kNormRectTag).IsEmpty()) { + const auto& rect = cc->Inputs().Tag(kNormRectTag).Get(); + auto* rectangle = NewRect(options_, render_data.get()); SetRect(/*normalized=*/true, rect.x_center() - rect.width() / 2.f, rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(), rect.rotation(), rectangle); } if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) { const auto& rect = cc->Inputs().Tag(kRectTag).Get(); + auto* rectangle = NewRect(options_, render_data.get()); SetRect(/*normalized=*/false, rect.x_center() - rect.width() / 2.f, rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(), rect.rotation(), rectangle); } + if (cc->Inputs().HasTag(kNormRectsTag) && + !cc->Inputs().Tag(kNormRectsTag).IsEmpty()) { + const auto& rects = + cc->Inputs().Tag(kNormRectsTag).Get>(); + for (auto& rect : rects) { + auto* rectangle = NewRect(options_, render_data.get()); + SetRect(/*normalized=*/true, rect.x_center() - rect.width() / 2.f, + rect.y_center() - rect.height() / 2.f, rect.width(), + rect.height(), rect.rotation(), rectangle); + } + } + if (cc->Inputs().HasTag(kRectsTag) && + !cc->Inputs().Tag(kRectsTag).IsEmpty()) { + const auto& rects = cc->Inputs().Tag(kRectsTag).Get>(); + for (auto& rect : rects) { + auto* rectangle = NewRect(options_, render_data.get()); + SetRect(/*normalized=*/false, rect.x_center() - rect.width() / 2.f, + rect.y_center() - rect.height() / 2.f, rect.width(), + rect.height(), rect.rotation(), rectangle); + } + } cc->Outputs() .Tag(kRenderDataTag) diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc new file mode 100644 index 000000000..18f2eec62 --- /dev/null +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -0,0 +1,194 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/util/resource_util.h" + +#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \ + (defined(__APPLE__) && !TARGET_OS_OSX) +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/helpers.h" +#else +#include "mediapipe/framework/port/file_helpers.h" +#endif + +namespace mediapipe { +// A calculator that takes a vector of scores and returns the indexes, scores, +// labels of the top k elements. 
+// +// Usage example: +// node { +// calculator: "TopKScoresCalculator" +// input_stream: "SCORES:score_vector" +// output_stream: "TOP_K_INDEXES:top_k_indexes" +// output_stream: "TOP_K_SCORES:top_k_scores" +// output_stream: "TOP_K_LABELS:top_k_labels" +// options: { +// [mediapipe.TopKScoresCalculatorOptions.ext] { +// top_k: 5 +// threshold: 0.1 +// label_map_path: "/path/to/label/map" +// } +// } +// } +class TopKScoresCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + ::mediapipe::Status LoadLabelmap(std::string label_map_path); + + int top_k_ = -1; + float threshold_ = 0.0; + std::unordered_map label_map_; +}; +REGISTER_CALCULATOR(TopKScoresCalculator); + +::mediapipe::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("SCORES")); + cc->Inputs().Tag("SCORES").Set>(); + if (cc->Outputs().HasTag("TOP_K_INDEXES")) { + cc->Outputs().Tag("TOP_K_INDEXES").Set>(); + } + if (cc->Outputs().HasTag("TOP_K_SCORES")) { + cc->Outputs().Tag("TOP_K_SCORES").Set>(); + } + if (cc->Outputs().HasTag("TOP_K_LABELS")) { + cc->Outputs().Tag("TOP_K_LABELS").Set>(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TopKScoresCalculator::Open(CalculatorContext* cc) { + const auto& options = cc->Options<::mediapipe::TopKScoresCalculatorOptions>(); + RET_CHECK(options.has_top_k() || options.has_threshold()) + << "Must specify at least one of the top_k and threshold fields in " + "TopKScoresCalculatorOptions."; + if (options.has_top_k()) { + RET_CHECK(options.top_k() > 0) << "top_k must be greater than zero."; + top_k_ = options.top_k(); + } + if (options.has_threshold()) { + threshold_ = options.threshold(); + } + if (options.has_label_map_path()) { + MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path())); + } + if (cc->Outputs().HasTag("TOP_K_LABELS")) { + RET_CHECK(!label_map_.empty()); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TopKScoresCalculator::Process(CalculatorContext* cc) { + const std::vector& input_vector = + cc->Inputs().Tag("SCORES").Get>(); + std::vector top_k_indexes; + + std::vector top_k_scores; + + std::vector top_k_labels; + + if (top_k_ > 0) { + top_k_indexes.reserve(top_k_); + top_k_scores.reserve(top_k_); + top_k_labels.reserve(top_k_); + } + std::priority_queue, std::vector>, + std::greater>> + pq; + for (int i = 0; i < input_vector.size(); ++i) { + if (input_vector[i] < threshold_) { + continue; + } + if (top_k_ > 0) { + if (pq.size() < top_k_) { + pq.push(std::pair(input_vector[i], i)); + } else if (pq.top().first < input_vector[i]) { + pq.pop(); + pq.push(std::pair(input_vector[i], i)); + } + } else { + pq.push(std::pair(input_vector[i], i)); + } + } + + while (!pq.empty()) { + top_k_indexes.push_back(pq.top().second); + top_k_scores.push_back(pq.top().first); + pq.pop(); + } + reverse(top_k_indexes.begin(), top_k_indexes.end()); + reverse(top_k_scores.begin(), top_k_scores.end()); + + if (cc->Outputs().HasTag("TOP_K_LABELS")) { + for (int index : top_k_indexes) { + top_k_labels.push_back(label_map_[index]); + } + } + if (cc->Outputs().HasTag("TOP_K_INDEXES")) { + cc->Outputs() + .Tag("TOP_K_INDEXES") + .AddPacket(MakePacket>(top_k_indexes) + .At(cc->InputTimestamp())); + } + if (cc->Outputs().HasTag("TOP_K_SCORES")) { + cc->Outputs() + .Tag("TOP_K_SCORES") + .AddPacket(MakePacket>(top_k_scores) 
+ .At(cc->InputTimestamp())); + } + if (cc->Outputs().HasTag("TOP_K_LABELS")) { + cc->Outputs() + .Tag("TOP_K_LABELS") + .AddPacket(MakePacket>(top_k_labels) + .At(cc->InputTimestamp())); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TopKScoresCalculator::LoadLabelmap( + std::string label_map_path) { + std::string string_path; + ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(label_map_path)); + std::string label_map_string; + MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); + + std::istringstream stream(label_map_string); + std::string line; + int i = 0; + while (std::getline(stream, line)) { + label_map_[i++] = line; + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/top_k_scores_calculator.proto b/mediapipe/calculators/util/top_k_scores_calculator.proto new file mode 100644 index 000000000..08fb7a756 --- /dev/null +++ b/mediapipe/calculators/util/top_k_scores_calculator.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TopKScoresCalculatorOptions { + extend CalculatorOptions { + optional TopKScoresCalculatorOptions ext = 271211788; + } + // How many highest scoring packets to output. + optional int32 top_k = 1; + + // If set, only keep the scores that are greater than the threshold. + optional float threshold = 2; + + // Path to a label map file for getting the actual name of classes. + optional string label_map_path = 3; +} diff --git a/mediapipe/calculators/util/top_k_scores_calculator_test.cc b/mediapipe/calculators/util/top_k_scores_calculator_test.cc new file mode 100644 index 000000000..7daeb5c0c --- /dev/null +++ b/mediapipe/calculators/util/top_k_scores_calculator_test.cc @@ -0,0 +1,150 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +TEST(TopKScoresCalculatorTest, TestNodeConfig) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TopKScoresCalculator" + input_stream: "SCORES:score_vector" + output_stream: "TOP_K_INDEXES:top_k_indexes" + output_stream: "TOP_K_SCORES:top_k_scores" + options: { + [mediapipe.TopKScoresCalculatorOptions.ext] {} + } + )")); + + auto status = runner.Run(); + ASSERT_TRUE(!status.ok()); + EXPECT_THAT( + status.ToString(), + testing::HasSubstr( + "Must specify at least one of the top_k and threshold fields")); +} + +TEST(TopKScoresCalculatorTest, TestTopKOnly) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TopKScoresCalculator" + input_stream: "SCORES:score_vector" + output_stream: "TOP_K_INDEXES:top_k_indexes" + output_stream: "TOP_K_SCORES:top_k_scores" + options: { + [mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 2 } + } + )")); + + std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; + + runner.MutableInputs()->Tag("SCORES").packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()); + const std::vector& indexes_outputs = + runner.Outputs().Tag("TOP_K_INDEXES").packets; + ASSERT_EQ(1, indexes_outputs.size()); + const auto& indexes = indexes_outputs[0].Get>(); + EXPECT_EQ(2, indexes.size()); + EXPECT_EQ(3, indexes[0]); + EXPECT_EQ(0, indexes[1]); + const std::vector& scores_outputs = + runner.Outputs().Tag("TOP_K_SCORES").packets; + ASSERT_EQ(1, scores_outputs.size()); + const auto& scores = scores_outputs[0].Get>(); + EXPECT_EQ(2, scores.size()); + EXPECT_NEAR(1, scores[0], 1e-5); + EXPECT_NEAR(0.9, scores[1], 1e-5); +} + +TEST(TopKScoresCalculatorTest, TestThresholdOnly) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TopKScoresCalculator" + input_stream: "SCORES:score_vector" + output_stream: "TOP_K_INDEXES:top_k_indexes" + output_stream: "TOP_K_SCORES:top_k_scores" + options: { + [mediapipe.TopKScoresCalculatorOptions.ext] { threshold: 0.2 } + } + )")); + + std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; + + runner.MutableInputs()->Tag("SCORES").packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()); + const std::vector& indexes_outputs = + runner.Outputs().Tag("TOP_K_INDEXES").packets; + ASSERT_EQ(1, indexes_outputs.size()); + const auto& indexes = indexes_outputs[0].Get>(); + EXPECT_EQ(4, indexes.size()); + EXPECT_EQ(3, indexes[0]); + EXPECT_EQ(0, indexes[1]); + EXPECT_EQ(2, indexes[2]); + EXPECT_EQ(1, indexes[3]); + const std::vector& scores_outputs = + runner.Outputs().Tag("TOP_K_SCORES").packets; + ASSERT_EQ(1, scores_outputs.size()); + const auto& scores = scores_outputs[0].Get>(); + EXPECT_EQ(4, scores.size()); + EXPECT_NEAR(1.0, scores[0], 1e-5); + EXPECT_NEAR(0.9, scores[1], 1e-5); + EXPECT_NEAR(0.3, scores[2], 1e-5); + EXPECT_NEAR(0.2, scores[3], 1e-5); +} + +TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "TopKScoresCalculator" + input_stream: "SCORES:score_vector" + output_stream: "TOP_K_INDEXES:top_k_indexes" + output_stream: "TOP_K_SCORES:top_k_scores" + options: { + [mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 4 
threshold: 0.3 } + } + )")); + + std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; + + runner.MutableInputs()->Tag("SCORES").packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()); + const std::vector& indexes_outputs = + runner.Outputs().Tag("TOP_K_INDEXES").packets; + ASSERT_EQ(1, indexes_outputs.size()); + const auto& indexes = indexes_outputs[0].Get>(); + EXPECT_EQ(3, indexes.size()); + EXPECT_EQ(3, indexes[0]); + EXPECT_EQ(0, indexes[1]); + EXPECT_EQ(2, indexes[2]); + const std::vector& scores_outputs = + runner.Outputs().Tag("TOP_K_SCORES").packets; + ASSERT_EQ(1, scores_outputs.size()); + const auto& scores = scores_outputs[0].Get>(); + EXPECT_EQ(3, scores.size()); + EXPECT_NEAR(1.0, scores[0], 1e-5); + EXPECT_NEAR(0.9, scores[1], 1e-5); + EXPECT_NEAR(0.3, scores[2], 1e-5); +} + +} // namespace mediapipe diff --git a/mediapipe/docs/examples.md b/mediapipe/docs/examples.md index 6e3156283..404024b5f 100644 --- a/mediapipe/docs/examples.md +++ b/mediapipe/docs/examples.md @@ -51,6 +51,15 @@ and model details are described in the * [Android](./face_detection_mobile_gpu.md) * [iOS](./face_detection_mobile_gpu.md) +### Face Detection with CPU + +[Face Detection with CPU](./face_detection_mobile_cpu.md) illustrates using the +same TFLite model in a CPU-based pipeline. This example highlights how graphs +can be easily adapted to run on CPU v.s. GPU. + +* [Android](./face_detection_mobile_cpu.md) +* [iOS](./face_detection_mobile_cpu.md) + ### Hand Detection with GPU [Hand Detection with GPU](./hand_detection_mobile_gpu.md) illustrates how to use @@ -103,3 +112,29 @@ object detection models (TensorFlow and TFLite) using the MediaPipe C++ APIs. [Sobel edge detection]:https://en.wikipedia.org/wiki/Sobel_operator [CameraX]:https://developer.android.com/training/camerax + +### Face Detection on Desktop with Webcam + +[Face Detection on Desktop with Webcam](./face_detection_desktop.md) shows how +to use MediaPipe with a TFLite model for face detection on desktop using CPU or +GPU with live video from a webcam. + +* [Desktop GPU](./face_detection_desktop.md) +* [Desktop CPU](./face_detection_desktop.md) + +### Hand Tracking on Desktop with Webcam + +[Hand Tracking on Desktop with Webcam](./hand_tracking_desktop.md) shows how to +use MediaPipe with a TFLite model for hand tracking on desktop using CPU or GPU +with live video from a webcam. + +* [Desktop GPU](./hand_tracking_desktop.md) +* [Desktop CPU](./hand_tracking_desktop.md) + +### Hair Segmentation on Desktop with Webcam + +[Hair Segmentation on Desktop with Webcam](./hair_segmentation_desktop.md) shows +how to use MediaPipe with a TFLite model for hair segmentation on desktop using +GPU with live video from a webcam. + +* [Desktop GPU](./hair_segmentation_desktop.md) diff --git a/mediapipe/docs/face_detection_desktop.md b/mediapipe/docs/face_detection_desktop.md new file mode 100644 index 000000000..b95705262 --- /dev/null +++ b/mediapipe/docs/face_detection_desktop.md @@ -0,0 +1,265 @@ +## Face Detection on Desktop + +This is an example of using MediaPipe to run face detection models (TensorFlow +Lite) and render bounding boxes on the detected faces. To know more about the +face detection models, please refer to the model [`README file`]. Moreover, if +you are interested in running the same TensorfFlow Lite model on Android/iOS, +please see the +[Face Detection on GPU on Android/iOS](face_detection_mobile_gpu.md) and +[Face Detection on CPU on Android/iOS](face_detection_mobile_cpu.md) examples. 
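The desktop demo binaries referenced below are thin C++ drivers around this graph: they grab frames from the webcam with OpenCV, feed them into the graph as `ImageFrame` packets on `input_video`, and display the annotated `output_video` stream. The following is only a rough sketch of such a driver under those assumptions (the function name and exact include set are illustrative and not part of this change), using the standard `CalculatorGraph` and `OutputStreamPoller` APIs:

```c++
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/opencv_video_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

// Sketch of a webcam driver for the face detection graph below.
::mediapipe::Status RunFaceDetectionGraph(const std::string& graph_config_text) {
  auto config = ::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>(
      graph_config_text);
  ::mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(::mediapipe::OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("output_video"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));

  cv::VideoCapture capture(0);  // Default webcam.
  for (int64 frame_index = 0; capture.isOpened(); ++frame_index) {
    cv::Mat camera_frame_bgr;
    capture >> camera_frame_bgr;
    if (camera_frame_bgr.empty()) break;
    cv::Mat camera_frame;
    cv::cvtColor(camera_frame_bgr, camera_frame, cv::COLOR_BGR2RGB);

    // Wrap the frame in an ImageFrame packet and send it into the graph.
    auto input_frame = absl::make_unique<::mediapipe::ImageFrame>(
        ::mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows,
        ::mediapipe::ImageFrame::kDefaultAlignmentBoundary);
    camera_frame.copyTo(::mediapipe::formats::MatView(input_frame.get()));
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
        "input_video", ::mediapipe::Adopt(input_frame.release())
                           .At(::mediapipe::Timestamp(frame_index))));

    // Pull the annotated output frame and display it.
    ::mediapipe::Packet packet;
    if (!poller.Next(&packet)) break;
    const auto& output_frame = packet.Get<::mediapipe::ImageFrame>();
    cv::Mat output_mat = ::mediapipe::formats::MatView(&output_frame);
    cv::cvtColor(output_mat, output_mat, cv::COLOR_RGB2BGR);
    cv::imshow("Face Detection", output_mat);
    if (cv::waitKey(5) >= 0) break;
  }
  MP_RETURN_IF_ERROR(graph.CloseInputStream("input_video"));
  return graph.WaitUntilDone();
}
```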
+ +We show the face detection demos with TensorFlow Lite model using the Webcam: + +- [TensorFlow Lite Face Detection Demo with Webcam (CPU)](#tensorflow-lite-face-detection-demo-with-webcam-cpu) + +- [TensorFlow Lite Face Detection Demo with Webcam (GPU)](#tensorflow-lite-face-detection-demo-with-webcam-gpu) + +Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please +see +[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md). + +Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section. + +### TensorFlow Lite Face Detection Demo with Webcam (CPU) + +To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run: + +```bash +# Video from webcam running on desktop CPU +$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \ + mediapipe/examples/desktop/face_detection:face_detection_cpu + +# It should print: +# Target //mediapipe/examples/desktop/face_detection:face_detection_cpu up-to-date: +# bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu +# INFO: Elapsed time: 36.417s, Critical Path: 23.22s +# INFO: 711 processes: 710 linux-sandbox, 1 local. +# INFO: Build completed successfully, 734 total actions + +$ export GLOG_logtostderr=1 +# This will open up your webcam as long as it is connected and on +# Any errors is likely due to your webcam being not accessible +$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \ + --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt +``` + +### TensorFlow Lite Face Detection Demo with Webcam (GPU) + +To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run: + +```bash +# Video from webcam running on desktop GPU +# This works only for linux currently +$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \ + mediapipe/examples/desktop/face_detection:face_detection_gpu + +# It should print: +# Target //mediapipe/examples/desktop/face_detection:face_detection_gpu up-to-date: +# bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu +# INFO: Elapsed time: 36.417s, Critical Path: 23.22s +# INFO: 711 processes: 710 linux-sandbox, 1 local. +# INFO: Build completed successfully, 734 total actions + +$ export GLOG_logtostderr=1 +# This will open up your webcam as long as it is connected and on +# Any errors is likely due to your webcam being not accessible, +# or GPU drivers not setup properly. +$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \ + --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt +``` + +#### Graph + +![graph visualization](images/face_detection_desktop.png) + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev). + +```bash +# MediaPipe graph that performs face detection with TensorFlow Lite on CPU & GPU. +# Used in the examples in +# mediapipie/examples/desktop/face_detection:face_detection_cpu. + +# Images on CPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. 
All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on CPU to a 128x128 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:throttled_input_video" + output_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 128 + output_height: 128 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "TENSORS:image_tensor" +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/face_detection_front.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 4 + min_scale: 0.1484375 + max_scale: 0.75 + input_size_height: 128 + input_size_width: 128 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 16 + strides: 16 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. 
+node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 896 + num_coords: 16 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 6 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.75 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text ("Face"). The label +# map is provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:output_detections" +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:throttled_input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} +``` + +[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md diff --git a/mediapipe/docs/face_detection_mobile_cpu.md b/mediapipe/docs/face_detection_mobile_cpu.md new file mode 100644 index 000000000..bcef66a48 --- /dev/null +++ b/mediapipe/docs/face_detection_mobile_cpu.md @@ -0,0 +1,244 @@ +# Face Detection (CPU) + +This doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt) +that performs face detection with TensorFlow Lite on CPU. 
+ +## Android + +[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu) + +To build and install the app: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/facedetectioncpu.apk +``` + +## iOS + +[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/facedetectioncpu). + +See the general [instructions](./mediapipe_ios_setup.md) for building iOS +examples and generating an Xcode project. This will be the FaceDetectionCpuApp +target. + +To build on the command line: + +```bash +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp +``` + +## Graph + +![face_detection_mobile_cpu_graph](images/mobile/face_detection_mobile_cpu.png) + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). + +[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt) + +```bash +# MediaPipe graph that performs face detection with TensorFlow Lite on CPU. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectioncpu and +# mediapipie/examples/ios/facedetectioncpu. + +# Images on GPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transfers the input image from GPU to CPU memory for the purpose of +# demonstrating a CPU-based pipeline. Note that the input image on GPU has the +# origin defined at the bottom-left corner (OpenGL convention). As a result, +# the transferred image on CPU also shares the same representation. +node: { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "throttled_input_video" + output_stream: "input_video_cpu" +} + +# Transforms the input image on CPU to a 128x128 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. 
+node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:input_video_cpu" + output_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 128 + output_height: 128 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "TENSORS:image_tensor" +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/face_detection_front.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 4 + min_scale: 0.1484375 + max_scale: 0.75 + input_size_height: 128 + input_size_width: 128 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 16 + strides: 16 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 896 + num_coords: 16 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 6 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.75 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text ("Face"). The label +# map is provided in the label_map_path option. 
+node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:output_detections" +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:input_video_cpu" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video_cpu" +} + +# Transfers the annotated image from CPU back to GPU memory, to be sent out of +# the graph. +node: { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "output_video_cpu" + output_stream: "output_video" +} +``` diff --git a/mediapipe/docs/hair_segmentation_desktop.md b/mediapipe/docs/hair_segmentation_desktop.md new file mode 100644 index 000000000..058902363 --- /dev/null +++ b/mediapipe/docs/hair_segmentation_desktop.md @@ -0,0 +1,209 @@ +## Hair Segmentation on Desktop + +This is an example of using MediaPipe to run hair segmentation models +(TensorFlow Lite) and render a color to the detected hair. To know more about +the hair segmentation models, please refer to the model [`README file`]. +Moreover, if you are interested in running the same TensorfFlow Lite model on +Android/iOS, please see the +[Hair Segmentation on GPU on Android/iOS](hair_segmentation_mobile_gpu.md) and + +We show the hair segmentation demos with TensorFlow Lite model using the Webcam: + +- [TensorFlow Lite Hair Segmentation Demo with Webcam (GPU)](#tensorflow-lite-hair-segmentation-demo-with-webcam-gpu) + +Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please +see +[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md). + +Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section. + +### TensorFlow Lite Hair Segmentation Demo with Webcam (GPU) + +To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run: + +```bash +# Video from webcam running on desktop GPU +# This works only for linux currently +$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \ + mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu + +# It should print: +#INFO: Found 1 target... 
+#Target //mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu up-to-date: +# bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu +#INFO: Elapsed time: 18.209s, Forge stats: 13026/13057 actions cached, 20.8s CPU used, 0.0s queue time, 89.3 MB ObjFS output (novel bytes: 87.4 MB), 0.0 MB local output, Critical Path: 11.88s, Remote (86.01% of the time): [queue: 0.00%, network: 16.83%, setup: 4.59%, process: 38.92%] +#INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142 +#INFO: Build completed successfully, 12210 total actions + +$ export GLOG_logtostderr=1 +# This will open up your webcam as long as it is connected and on +# Any errors is likely due to your webcam being not accessible, +# or GPU drivers not setup properly. +$ bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \ + --calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt +``` + +#### Graph + +![hair_segmentation_mobile_gpu_graph](images/mobile/hair_segmentation_mobile_gpu.png) + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev). + +```bash +# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu. + +# Images on GPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToSegmentationCalculator downstream in the graph to finish +# generating the corresponding hair mask before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToSegmentationCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:hair_mask" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on GPU to a 512x512 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the hair +# segmentation model used in this graph is agnostic to that deformation. 
+node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:throttled_input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 512 + output_height: 512 + } + } +} + +# Caches a mask fed back from the previous round of hair segmentation, and upon +# the arrival of the next input image sends out the cached mask with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous mask. Note that upon the arrival of the very first +# input image, an empty packet is sent out to jump start the feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:throttled_input_video" + input_stream: "LOOP:hair_mask" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:previous_hair_mask" +} + +# Embeds the hair mask generated from the previous round of hair segmentation +# as the alpha channel of the current input image. +node { + calculator: "SetAlphaCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + input_stream: "ALPHA_GPU:previous_hair_mask" + output_stream: "IMAGE_GPU:mask_embedded_input_video" +} + +# Converts the transformed input image on GPU into an image tensor stored in +# tflite::gpu::GlBuffer. The zero_center option is set to false to normalize the +# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. With the +# max_num_channels option set to 4, all 4 RGBA channels are contained in the +# image tensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:mask_embedded_input_video" + output_stream: "TENSORS_GPU:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: false + max_num_channels: 4 + } + } +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "op_resolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { + use_gpu: true + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# tensor representing the hair segmentation, which has the same width and height +# as the input image tensor. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS_GPU:segmentation_tensor" + input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/hair_segmentation.tflite" + use_gpu: true + } + } +} + +# Decodes the segmentation tensor generated by the TensorFlow Lite model into a +# mask of values in [0.f, 1.f], stored in the R channel of a GPU buffer. It also +# takes the mask generated previously as another input to improve the temporal +# consistency. 
+node { + calculator: "TfLiteTensorsToSegmentationCalculator" + input_stream: "TENSORS_GPU:segmentation_tensor" + input_stream: "PREV_MASK_GPU:previous_hair_mask" + output_stream: "MASK_GPU:hair_mask" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { + tensor_width: 512 + tensor_height: 512 + tensor_channels: 2 + combine_with_previous_ratio: 0.9 + output_layer_index: 1 + } + } +} + +# Colors the hair segmentation with the color specified in the option. +node { + calculator: "RecolorCalculator" + input_stream: "IMAGE_GPU:throttled_input_video" + input_stream: "MASK_GPU:hair_mask" + output_stream: "IMAGE_GPU:output_video" + node_options: { + [type.googleapis.com/mediapipe.RecolorCalculatorOptions] { + color { r: 0 g: 0 b: 255 } + mask_channel: RED + } + } +} +``` + +[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/README.md diff --git a/mediapipe/docs/hand_detection_mobile_gpu.md b/mediapipe/docs/hand_detection_mobile_gpu.md index 141c0a26f..5c4a41bbd 100644 --- a/mediapipe/docs/hand_detection_mobile_gpu.md +++ b/mediapipe/docs/hand_detection_mobile_gpu.md @@ -1,7 +1,7 @@ # Hand Detection (GPU) This doc focuses on the -[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt) that performs hand detection with TensorFlow Lite on GPU. It is related to the [hand tracking example](./hand_tracking_mobile_gpu.md). @@ -147,7 +147,7 @@ node { ![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png) -[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) +[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt) ```bash # MediaPipe hand detection subgraph. diff --git a/mediapipe/docs/hand_tracking_desktop.md b/mediapipe/docs/hand_tracking_desktop.md new file mode 100644 index 000000000..6776a4710 --- /dev/null +++ b/mediapipe/docs/hand_tracking_desktop.md @@ -0,0 +1,184 @@ +## Hand Tracking on Desktop + +This is an example of using MediaPipe to run hand tracking models (TensorFlow +Lite) and render bounding boxes on the detected hand (one hand only). To know +more about the hand tracking models, please refer to the model [`README file`]. +Moreover, if you are interested in running the same TensorfFlow Lite model on +Android/iOS, please see the +[Hand Tracking on GPU on Android/iOS](hand_tracking_mobile_gpu.md) and + +We show the hand tracking demos with TensorFlow Lite model using the Webcam: + +- [TensorFlow Lite Hand Tracking Demo with Webcam (CPU)](#tensorflow-lite-hand-tracking-demo-with-webcam-cpu) + +- [TensorFlow Lite Hand Tracking Demo with Webcam (GPU)](#tensorflow-lite-hand-tracking-demo-with-webcam-gpu) + +Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please +see +[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md). + +Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section. 
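Besides the annotated `output_video` stream, the graph shown further below also carries the tracked hand rectangle on its `hand_rect` stream. If the hosting application wants the raw tracking result rather than (or in addition to) rendered frames, one option is to register an observer on that stream before starting the graph. A minimal sketch under that assumption (the helper function and logging are illustrative, not part of this change):

```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/logging.h"

// Call after graph->Initialize() and before graph->StartRun().
::mediapipe::Status ObserveHandRect(::mediapipe::CalculatorGraph* graph) {
  return graph->ObserveOutputStream(
      "hand_rect",
      [](const ::mediapipe::Packet& packet) -> ::mediapipe::Status {
        // The "hand_rect" stream carries a NormalizedRect selected from either
        // the detection or the landmark subgraph (see the MergeCalculator node
        // in the graph below).
        const auto& rect = packet.Get<::mediapipe::NormalizedRect>();
        LOG(INFO) << "hand rect center: (" << rect.x_center() << ", "
                  << rect.y_center() << "), rotation: " << rect.rotation();
        return ::mediapipe::OkStatus();
      });
}
```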
+ +### TensorFlow Lite Hand Tracking Demo with Webcam (CPU) + +To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run: + +```bash +# Video from webcam running on desktop CPU +$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \ + mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu + +# It should print: +#Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu up-to-date: +# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu +#INFO: Elapsed time: 22.645s, Forge stats: 13356/13463 actions cached, 1.5m CPU used, 0.0s queue time, 819.8 MB ObjFS output (novel bytes: 85.6 MB), 0.0 MB local output, Critical Path: 14.43s, Remote (87.25% of the time): [queue: 0.00%, network: 14.88%, setup: 4.80%, process: 39.80%, fetch: 18.15%] +#INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307 +#INFO: Build completed successfully, 12517 total actions + +$ export GLOG_logtostderr=1 +# This will open up your webcam as long as it is connected and on +# Any errors are likely due to your webcam not being accessible +$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \ + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt +``` + +### TensorFlow Lite Hand Tracking Demo with Webcam (GPU) + +To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run: + +```bash +# Video from webcam running on desktop GPU +# This works only for Linux currently +$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \ + mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu + +# It should print: +# Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu up-to-date: +# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu +#INFO: Elapsed time: 84.055s, Forge stats: 6858/19343 actions cached, 1.6h CPU used, 0.9s queue time, 1.68 GB ObjFS output (novel bytes: 485.1 MB), 0.0 MB local output, Critical Path: 48.14s, Remote (99.40% of the time): [queue: 0.00%, setup: 5.59%, process: 74.44%] +#INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b +#INFO: Build completed successfully, 22455 total actions + +$ export GLOG_logtostderr=1 +# This will open up your webcam as long as it is connected and on +# Any errors are likely due to your webcam not being accessible, +# or GPU drivers not being set up properly. +$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \ + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt +``` + +#### Graph + +![graph visualization](images/hand_tracking_desktop.png) + +To visualize the graph as shown above, copy the text specification of the graph +below and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev). + +```bash +# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite +# on CPU & GPU. +# Used in the example in +# mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon +# the arrival of the next input image sends out the cached decision with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand-presence decision. Note that upon the arrival +# of the very first input image, an empty packet is sent out to jump start the +# feedback loop.
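+# The LOOP input is declared as a back edge (see input_stream_info on the
+# node below), which tells the framework that this cycle is intentional and
+# should not stall scheduling while waiting for the first decision.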
+node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:input_video" + input_stream: "LOOP:hand_presence" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_presence" +} + +# Drops the incoming image if HandLandmarkSubgraph was able to identify hand +# presence in the previous image. Otherwise, passes the incoming image through +# to trigger a new round of hand detection in HandDetectionSubgraph. +node { + calculator: "GateCalculator" + input_stream: "input_video" + input_stream: "DISALLOW:prev_hand_presence" + output_stream: "hand_detection_input_video" + + node_options: { + [type.googleapis.com/mediapipe.GateCalculatorOptions] { + empty_packets_as_allow: true + } + } +} + +# Subgraph that detections hands (see hand_detection_cpu.pbtxt). +node { + calculator: "HandDetectionSubgraph" + input_stream: "hand_detection_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +# Subgraph that localizes hand landmarks (see hand_landmark_cpu.pbtxt). +node { + calculator: "HandLandmarkSubgraph" + input_stream: "IMAGE:input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "LANDMARKS:hand_landmarks" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + output_stream: "PRESENCE:hand_presence" +} + +# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the +# arrival of the next input image sends out the cached rectangle with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand rectangle. Note that upon the arrival of the +# very first input image, an empty packet is sent out to jump start the +# feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:input_video" + input_stream: "LOOP:hand_rect_from_landmarks" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks" +} + +# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that +# generated by HandLandmarkSubgraph into a single output stream by selecting +# between one of the two streams. The former is selected if the incoming packet +# is not empty, i.e., hand detection is performed on the current image by +# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand +# presence in the previous image). Otherwise, the latter is selected, which is +# never empty because HandLandmarkSubgraphs processes all images (that went +# through FlowLimiterCaculator). +node { + calculator: "MergeCalculator" + input_stream: "hand_rect_from_palm_detections" + input_stream: "prev_hand_rect_from_landmarks" + output_stream: "hand_rect" +} + +# Subgraph that renders annotations and overlays them on top of the input +# images (see renderer_cpu.pbtxt). 
+node { + calculator: "RendererSubgraph" + input_stream: "IMAGE:input_video" + input_stream: "LANDMARKS:hand_landmarks" + input_stream: "NORM_RECT:hand_rect" + input_stream: "DETECTIONS:palm_detections" + output_stream: "IMAGE:output_video" +} + +``` + +[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/README.md diff --git a/mediapipe/docs/hand_tracking_mobile_gpu.md b/mediapipe/docs/hand_tracking_mobile_gpu.md index 08af5dcc5..be9cdd264 100644 --- a/mediapipe/docs/hand_tracking_mobile_gpu.md +++ b/mediapipe/docs/hand_tracking_mobile_gpu.md @@ -227,7 +227,7 @@ node { ![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png) -[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) +[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt) ```bash # MediaPipe hand detection subgraph. @@ -433,7 +433,7 @@ node { ![hand_landmark_gpu_subgraph.pbtxt](images/mobile/hand_landmark_gpu_subgraph.png) -[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt) +[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt) ```bash # MediaPipe hand landmark localization subgraph. @@ -617,7 +617,7 @@ node { ![hand_renderer_gpu_subgraph.pbtxt](images/mobile/hand_renderer_gpu_subgraph.png) -[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt) +[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/renderer_gpu.pbtxt) ```bash # MediaPipe hand tracking rendering subgraph. diff --git a/mediapipe/docs/images/face_detection_desktop.png b/mediapipe/docs/images/face_detection_desktop.png new file mode 100644 index 000000000..7f5f8ab1b Binary files /dev/null and b/mediapipe/docs/images/face_detection_desktop.png differ diff --git a/mediapipe/docs/images/hand_tracking_desktop.png b/mediapipe/docs/images/hand_tracking_desktop.png new file mode 100644 index 000000000..cfc38f057 Binary files /dev/null and b/mediapipe/docs/images/hand_tracking_desktop.png differ diff --git a/mediapipe/docs/images/mobile/face_detection_mobile_cpu.png b/mediapipe/docs/images/mobile/face_detection_mobile_cpu.png new file mode 100644 index 000000000..e57caa23a Binary files /dev/null and b/mediapipe/docs/images/mobile/face_detection_mobile_cpu.png differ diff --git a/mediapipe/docs/images/mobile/hand_tracking_mobile.png b/mediapipe/docs/images/mobile/hand_tracking_mobile.png index 66b9a7a9e..3b2063190 100644 Binary files a/mediapipe/docs/images/mobile/hand_tracking_mobile.png and b/mediapipe/docs/images/mobile/hand_tracking_mobile.png differ diff --git a/mediapipe/docs/install.md b/mediapipe/docs/install.md index 4b82b7cb0..99473811e 100644 --- a/mediapipe/docs/install.md +++ b/mediapipe/docs/install.md @@ -107,14 +107,34 @@ To build and run iOS apps: ) ``` -4. Run the [Hello World desktop example](./hello_world_desktop.md). +4. For running desktop examples on Linux only (not on OS X) with GPU + acceleration. + + ```bash + # Requires a GPU with EGL driver support. + # Can use mesa GPU libraries for desktop, (or Nvidia/AMD equivalent). 
+ sudo apt-get install mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev + + # To compile with GPU support, replace + --define MEDIAPIPE_DISABLE_GPU=1 + # with + --copt -DMESA_EGL_NO_X11_HEADERS + # when building GPU examples. + ``` + +5. Run the [Hello World desktop example](./hello_world_desktop.md). ```bash $ export GLOG_logtostderr=1 - # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported + + # If you are running on Linux desktop with CPU only $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world + # If you are running on Linux desktop with GPU support enabled (via mesa drivers) + $ bazel run --copt -DMESA_EGL_NO_X11_HEADERS \ + mediapipe/examples/desktop/hello_world:hello_world + # Should print: # Hello World! # Hello World! @@ -194,7 +214,7 @@ To build and run iOS apps: ```bash $ export GLOG_logtostderr=1 - # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported + # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' if you are running on Linux desktop with CPU only $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world diff --git a/mediapipe/docs/object_detection_desktop.md b/mediapipe/docs/object_detection_desktop.md index 88334993e..63de4f1ef 100644 --- a/mediapipe/docs/object_detection_desktop.md +++ b/mediapipe/docs/object_detection_desktop.md @@ -10,8 +10,9 @@ interested in running the same TensorfFlow Lite model on Android, please see the We show the object detection demo with both TensorFlow model and TensorFlow Lite model: -- [TensorFlow Object Detection Demo](#tensorflow-object-detection-demo) -- [TensorFlow Lite Object Detection Demo](#tensorflow-lite-object-detection-demo) +- [TensorFlow Object Detection Demo](#tensorflow-object-detection-demo) +- [TensorFlow Lite Object Detection Demo](#tensorflow-lite-object-detection-demo) +- [TensorFlow Lite Object Detection Demo with Webcam (CPU)](#tensorflow-lite-object-detection-demo-with-webcam-cpu) Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section.
@@ -207,6 +208,29 @@ $ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite --input_side_packets=input_video_path=,output_video_path= ``` +### TensorFlow Lite Object Detection Demo with Webcam (CPU) + +To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run: + +```bash +# Video from webcam running on desktop CPU +$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \ + mediapipe/examples/desktop/object_detection:object_detection_cpu + +# It should print: +#Target //mediapipe/examples/desktop/object_detection:object_detection_cpu up-to-date: +# bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu +#INFO: Elapsed time: 16.020s, Forge stats: 13001/13003 actions cached, 2.1s CPU used, 0.0s queue time, 89.0 MB ObjFS output (novel bytes: 88.0 MB), 0.0 MB local output, Critical Path: 10.01s, Remote (41.42% of the time): [queue: 0.00%, setup: 4.21%, process: 12.48%] +#INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b +#INFO: Build completed successfully, 12154 total actions + +$ export GLOG_logtostderr=1 +# This will open up your webcam as long as it is connected and on +# Any errors are likely due to your webcam not being accessible +$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \ + --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt +``` + #### Graph ![graph visualization](images/object_detection_desktop_tflite.png) diff --git a/mediapipe/docs/visualizer.md b/mediapipe/docs/visualizer.md index 86cbfe1e6..9cab5dd4b 100644 --- a/mediapipe/docs/visualizer.md +++ b/mediapipe/docs/visualizer.md @@ -77,7 +77,7 @@ For instance, there are two graphs involved in the [hand detection example](./hand_detection_mobile_gpu.md): the main graph ([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt)) and its associated subgraph -([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt)). +([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt)). To visualize them: * In the MediaPipe visualizer, click on the upload graph button and select the diff --git a/mediapipe/examples/desktop/BUILD b/mediapipe/examples/desktop/BUILD index 829601df7..3a35d724b 100644 --- a/mediapipe/examples/desktop/BUILD +++ b/mediapipe/examples/desktop/BUILD @@ -14,7 +14,9 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +package(default_visibility = [ + "//visibility:public", +]) cc_library( name = "simple_run_graph_main", @@ -29,3 +31,44 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "demo_run_graph_main", + srcs = ["demo_run_graph_main.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_video", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + +# Linux only.
+# Must have a GPU with EGL support: +# ex: sudo aptitude install mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev +# (or similar nvidia/amd equivalent) +cc_library( + name = "demo_run_graph_main_gpu", + srcs = ["demo_run_graph_main_gpu.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_video", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_shared_data_internal", + ], +) diff --git a/mediapipe/examples/desktop/demo_run_graph_main.cc b/mediapipe/examples/desktop/demo_run_graph_main.cc new file mode 100644 index 000000000..14136560c --- /dev/null +++ b/mediapipe/examples/desktop/demo_run_graph_main.cc @@ -0,0 +1,146 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// An example of sending OpenCV webcam frames into a MediaPipe graph. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_video_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" + +constexpr char kInputStream[] = "input_video"; +constexpr char kOutputStream[] = "output_video"; +constexpr char kWindowName[] = "MediaPipe"; + +DEFINE_string( + calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +DEFINE_string(input_video_path, "", + "Full path of video to load. " + "If not provided, attempt to use a webcam."); +DEFINE_string(output_video_path, "", + "Full path of where to save result (.mp4 only). 
" + "If not provided, show result in a window."); + +::mediapipe::Status RunMPPGraph() { + std::string calculator_graph_config_contents; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents( + FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + calculator_graph_config_contents); + + LOG(INFO) << "Initialize the calculator graph."; + mediapipe::CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize(config)); + + LOG(INFO) << "Initialize the camera or load the video."; + cv::VideoCapture capture; + const bool load_video = !FLAGS_input_video_path.empty(); + if (load_video) { + capture.open(FLAGS_input_video_path); + } else { + capture.open(0); + } + RET_CHECK(capture.isOpened()); + + cv::VideoWriter writer; + const bool save_video = !FLAGS_output_video_path.empty(); + if (save_video) { + LOG(INFO) << "Prepare video writer."; + cv::Mat test_frame; + capture.read(test_frame); // Consume first frame. + capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning. + writer.open(FLAGS_output_video_path, + mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 + capture.get(cv::CAP_PROP_FPS), test_frame.size()); + RET_CHECK(writer.isOpened()); + } else { + cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); + } + + LOG(INFO) << "Start running the calculator graph."; + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, + graph.AddOutputStreamPoller(kOutputStream)); + MP_RETURN_IF_ERROR(graph.StartRun({})); + + LOG(INFO) << "Start grabbing and processing frames."; + size_t frame_timestamp = 0; + bool grab_frames = true; + while (grab_frames) { + // Capture opencv camera or video frame. + cv::Mat camera_frame_raw; + capture >> camera_frame_raw; + if (camera_frame_raw.empty()) break; // End of video. + cv::Mat camera_frame; + cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB); + if (!load_video) { + cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1); + } + + // Wrap Mat into an ImageFrame. + auto input_frame = absl::make_unique( + mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows, + mediapipe::ImageFrame::kDefaultAlignmentBoundary); + cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get()); + camera_frame.copyTo(input_frame_mat); + + // Send image packet into the graph. + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( + kInputStream, mediapipe::Adopt(input_frame.release()) + .At(mediapipe::Timestamp(frame_timestamp++)))); + + // Get the graph result packet, or stop if that fails. + mediapipe::Packet packet; + if (!poller.Next(&packet)) break; + auto& output_frame = packet.Get(); + + // Convert back to opencv for display or saving. + cv::Mat output_frame_mat = mediapipe::formats::MatView(&output_frame); + cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); + if (save_video) { + writer.write(output_frame_mat); + } else { + cv::imshow(kWindowName, output_frame_mat); + // Press any key to exit. 
+ const int pressed_key = cv::waitKey(5); + if (pressed_key >= 0 && pressed_key != 255) grab_frames = false; + } + } + + LOG(INFO) << "Shutting down."; + if (writer.isOpened()) writer.release(); + MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream)); + return graph.WaitUntilDone(); +} + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + ::mediapipe::Status run_status = RunMPPGraph(); + if (!run_status.ok()) { + LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + } else { + LOG(INFO) << "Success!"; + } + return 0; +} diff --git a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc new file mode 100644 index 000000000..4bf8cf97a --- /dev/null +++ b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc @@ -0,0 +1,186 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// An example of sending OpenCV webcam frames into a MediaPipe graph. +// This example requires a linux computer and a GPU with EGL support drivers. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_video_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_shared_data_internal.h" + +constexpr char kInputStream[] = "input_video"; +constexpr char kOutputStream[] = "output_video"; +constexpr char kWindowName[] = "MediaPipe"; + +DEFINE_string( + calculator_graph_config_file, "", + "Name of file containing text format CalculatorGraphConfig proto."); +DEFINE_string(input_video_path, "", + "Full path of video to load. " + "If not provided, attempt to use a webcam."); +DEFINE_string(output_video_path, "", + "Full path of where to save result (.mp4 only). 
" + "If not provided, show result in a window."); + +::mediapipe::Status RunMPPGraph() { + std::string calculator_graph_config_contents; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents( + FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + calculator_graph_config_contents); + + LOG(INFO) << "Initialize the calculator graph."; + mediapipe::CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize(config)); + + LOG(INFO) << "Initialize the GPU."; + ASSIGN_OR_RETURN(auto gpu_resources, mediapipe::GpuResources::Create()); + MP_RETURN_IF_ERROR(graph.SetGpuResources(std::move(gpu_resources))); + mediapipe::GlCalculatorHelper gpu_helper; + gpu_helper.InitializeForTest(graph.GetGpuResources().get()); + + LOG(INFO) << "Initialize the camera or load the video."; + cv::VideoCapture capture; + const bool load_video = !FLAGS_input_video_path.empty(); + if (load_video) { + capture.open(FLAGS_input_video_path); + } else { + capture.open(0); + } + RET_CHECK(capture.isOpened()); + + cv::VideoWriter writer; + const bool save_video = !FLAGS_output_video_path.empty(); + if (save_video) { + LOG(INFO) << "Prepare video writer."; + cv::Mat test_frame; + capture.read(test_frame); // Consume first frame. + capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning. + writer.open(FLAGS_output_video_path, + mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 + capture.get(cv::CAP_PROP_FPS), test_frame.size()); + RET_CHECK(writer.isOpened()); + } else { + cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); + } + + LOG(INFO) << "Start running the calculator graph."; + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, + graph.AddOutputStreamPoller(kOutputStream)); + MP_RETURN_IF_ERROR(graph.StartRun({})); + + LOG(INFO) << "Start grabbing and processing frames."; + size_t frame_timestamp = 0; + bool grab_frames = true; + while (grab_frames) { + // Capture opencv camera or video frame. + cv::Mat camera_frame_raw; + capture >> camera_frame_raw; + if (camera_frame_raw.empty()) break; // End of video. + cv::Mat camera_frame; + cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB); + if (!load_video) { + cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1); + } + + // Wrap Mat into an ImageFrame. + auto input_frame = absl::make_unique( + mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get()); + camera_frame.copyTo(input_frame_mat); + + // Prepare and add graph input packet. + MP_RETURN_IF_ERROR( + gpu_helper.RunInGlContext([&input_frame, &frame_timestamp, &graph, + &gpu_helper]() -> ::mediapipe::Status { + // Convert ImageFrame to GpuBuffer. + auto texture = gpu_helper.CreateSourceTexture(*input_frame.get()); + auto gpu_frame = texture.GetFrame(); + glFlush(); + texture.Release(); + // Send GPU image packet into the graph. + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( + kInputStream, mediapipe::Adopt(gpu_frame.release()) + .At(mediapipe::Timestamp(frame_timestamp++)))); + return ::mediapipe::OkStatus(); + })); + + // Get the graph result packet, or stop if that fails. + mediapipe::Packet packet; + if (!poller.Next(&packet)) break; + std::unique_ptr output_frame; + + // Convert GpuBuffer to ImageFrame. 
+ MP_RETURN_IF_ERROR(gpu_helper.RunInGlContext( + [&packet, &output_frame, &gpu_helper]() -> ::mediapipe::Status { + auto& gpu_frame = packet.Get(); + auto texture = gpu_helper.CreateSourceTexture(gpu_frame); + output_frame = absl::make_unique( + mediapipe::ImageFormatForGpuBufferFormat(gpu_frame.format()), + gpu_frame.width(), gpu_frame.height(), + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + gpu_helper.BindFramebuffer(texture); + const auto info = + mediapipe::GlTextureInfoForGpuBufferFormat(gpu_frame.format(), 0); + glReadPixels(0, 0, texture.width(), texture.height(), info.gl_format, + info.gl_type, output_frame->MutablePixelData()); + glFlush(); + texture.Release(); + return ::mediapipe::OkStatus(); + })); + + // Convert back to opencv for display or saving. + cv::Mat output_frame_mat = mediapipe::formats::MatView(output_frame.get()); + cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); + if (save_video) { + writer.write(output_frame_mat); + } else { + cv::imshow(kWindowName, output_frame_mat); + // Press any key to exit. + const int pressed_key = cv::waitKey(5); + if (pressed_key >= 0 && pressed_key != 255) grab_frames = false; + } + } + + LOG(INFO) << "Shutting down."; + if (writer.isOpened()) writer.release(); + MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream)); + return graph.WaitUntilDone(); +} + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + ::mediapipe::Status run_status = RunMPPGraph(); + if (!run_status.ok()) { + LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + } else { + LOG(INFO) << "Success!"; + } + return 0; +} diff --git a/mediapipe/examples/desktop/face_detection/BUILD b/mediapipe/examples/desktop/face_detection/BUILD new file mode 100644 index 000000000..3d1dbcec8 --- /dev/null +++ b/mediapipe/examples/desktop/face_detection/BUILD @@ -0,0 +1,34 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_binary( + name = "face_detection_cpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main", + "//mediapipe/graphs/face_detection:desktop_tflite_calculators", + ], +) + +# Linux only +cc_binary( + name = "face_detection_gpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main_gpu", + "//mediapipe/graphs/face_detection:mobile_calculators", + ], +) diff --git a/mediapipe/examples/desktop/hair_segmentation/BUILD b/mediapipe/examples/desktop/hair_segmentation/BUILD new file mode 100644 index 000000000..0338feddf --- /dev/null +++ b/mediapipe/examples/desktop/hair_segmentation/BUILD @@ -0,0 +1,26 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +# Linux only +cc_binary( + name = "hair_segmentation_gpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main_gpu", + "//mediapipe/graphs/hair_segmentation:mobile_calculators", + ], +) diff --git a/mediapipe/examples/desktop/hand_tracking/BUILD b/mediapipe/examples/desktop/hand_tracking/BUILD new file mode 100644 index 000000000..1c99b00f6 --- /dev/null +++ b/mediapipe/examples/desktop/hand_tracking/BUILD @@ -0,0 +1,42 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_binary( + name = "hand_tracking_tflite", + deps = [ + "//mediapipe/examples/desktop:simple_run_graph_main", + "//mediapipe/graphs/hand_tracking:desktop_tflite_calculators", + ], +) + +cc_binary( + name = "hand_tracking_cpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main", + "//mediapipe/graphs/hand_tracking:desktop_tflite_calculators", + ], +) + +# Linux only +cc_binary( + name = "hand_tracking_gpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main_gpu", + "//mediapipe/graphs/hand_tracking:mobile_calculators", + ], +) diff --git a/mediapipe/examples/desktop/object_detection/BUILD b/mediapipe/examples/desktop/object_detection/BUILD index 4ce4fd900..ee6832069 100644 --- a/mediapipe/examples/desktop/object_detection/BUILD +++ b/mediapipe/examples/desktop/object_detection/BUILD @@ -72,3 +72,11 @@ cc_binary( "//mediapipe/graphs/object_detection:desktop_tflite_calculators", ], ) + +cc_binary( + name = "object_detection_cpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main", + "//mediapipe/graphs/object_detection:desktop_tflite_calculators", + ], +) diff --git a/mediapipe/examples/desktop/youtube8m/BUILD b/mediapipe/examples/desktop/youtube8m/BUILD index 8c97ffc8c..c25c5f50d 100644 --- a/mediapipe/examples/desktop/youtube8m/BUILD +++ b/mediapipe/examples/desktop/youtube8m/BUILD @@ -27,7 +27,7 @@ cc_binary( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", - "//mediapipe/graphs/youtube8m:yt8m_calculators_deps", + "//mediapipe/graphs/youtube8m:yt8m_feature_extraction_calculators", # TODO: Figure out the minimum set of the kernels needed by this example. 
"@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:direct_session", diff --git a/mediapipe/examples/desktop/youtube8m/README.md b/mediapipe/examples/desktop/youtube8m/README.md index b777cd789..2989a7927 100644 --- a/mediapipe/examples/desktop/youtube8m/README.md +++ b/mediapipe/examples/desktop/youtube8m/README.md @@ -37,7 +37,9 @@ ```bash python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \ - --path_to_input_video=/absolute/path/to/the/local/video/file + --path_to_input_video=/absolute/path/to/the/local/video/file \ + --clip_start_time_sec=0 \ + --clip_end_time_sec=10 ``` 5. Run the MediaPipe binary to extract the features diff --git a/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py b/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py index b814f9ea4..7438a5134 100644 --- a/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py +++ b/mediapipe/examples/desktop/youtube8m/generate_input_sequence_example.py @@ -37,20 +37,29 @@ def bytes23(string): def main(argv): - if len(argv) > 1: + if len(argv) > 3: raise app.UsageError('Too many command-line arguments.') if not flags.FLAGS.path_to_input_video: raise ValueError('You must specify the path to the input video.') + if not flags.FLAGS.clip_end_time_sec: + raise ValueError('You must specify the clip end timestamp in seconds.') + if flags.FLAGS.clip_start_time_sec >= flags.FLAGS.clip_end_time_sec: + raise ValueError( + 'The clip start time must be greater than the clip end time.') metadata = tf.train.SequenceExample() ms.set_clip_data_path(bytes23(flags.FLAGS.path_to_input_video), metadata) - ms.set_clip_start_timestamp(0, metadata) + ms.set_clip_start_timestamp( + flags.FLAGS.clip_start_time_sec * SECONDS_TO_MICROSECONDS, metadata) ms.set_clip_end_timestamp( - int(float(300 * SECONDS_TO_MICROSECONDS)), metadata) + flags.FLAGS.clip_end_time_sec * SECONDS_TO_MICROSECONDS, metadata) with open('/tmp/mediapipe/metadata.tfrecord', 'wb') as writer: writer.write(metadata.SerializeToString()) if __name__ == '__main__': flags.DEFINE_string('path_to_input_video', '', 'Path to the input video.') + flags.DEFINE_integer('clip_start_time_sec', 0, + 'Clip start timestamp in seconds') + flags.DEFINE_integer('clip_end_time_sec', 10, 'Clip end timestamp in seconds') app.run(main) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index bc83ebcd8..90a4f672c 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1134,6 +1134,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/framework/tool:calculator_graph_template_cc_proto", + "//mediapipe/framework/tool:options_util", "//mediapipe/framework/tool:template_expander", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -1499,13 +1500,47 @@ cc_test( deps = [ ":calculator_context", ":calculator_framework", + ":test_calculators", + ":thread_pool_executor", ":timestamp", + ":type_map", + "//mediapipe/calculators/core:counting_source_calculator", + "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:barrier_input_stream_handler", + "//mediapipe/framework/stream_handler:early_close_input_stream_handler", + 
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/stream_handler:mux_input_stream_handler", + "//mediapipe/framework/tool:sink", + ], +) + +cc_test( + name = "calculator_graph_side_packet_test", + size = "small", + srcs = [ + "calculator_graph_side_packet_test.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":calculator_framework", + ":test_calculators", + "//mediapipe/calculators/core:counting_source_calculator", + "//mediapipe/calculators/core:mux_calculator", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/time", ], ) diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index 750042522..1b8c3e9f2 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -21,11 +21,258 @@ #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/timestamp.h" namespace mediapipe { namespace { +typedef std::function<::mediapipe::Status(CalculatorContext* cc)> + CalculatorContextFunction; + +// A simple Semaphore for synchronizing test threads. +class AtomicSemaphore { + public: + AtomicSemaphore(int64_t supply) : supply_(supply) {} + void Acquire(int64_t amount) { + while (supply_.fetch_sub(amount) - amount < 0) { + Release(amount); + } + } + void Release(int64_t amount) { supply_ += amount; } + + private: + std::atomic supply_; +}; + +// A mediapipe::Executor that signals the start and finish of each task. +class CountingExecutor : public Executor { + public: + CountingExecutor(int num_threads, std::function start_callback, + std::function finish_callback) + : thread_pool_(num_threads), + start_callback_(std::move(start_callback)), + finish_callback_(std::move(finish_callback)) { + thread_pool_.StartWorkers(); + } + void Schedule(std::function task) override { + start_callback_(); + thread_pool_.Schedule([this, task] { + task(); + finish_callback_(); + }); + } + + private: + ThreadPool thread_pool_; + std::function start_callback_; + std::function finish_callback_; +}; + +// A Calculator that adds the integer values in the packets in all the input +// streams and outputs the sum to the output stream. 
+class IntAdderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).Set(); + } + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int sum = 0; + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + sum += cc->Inputs().Index(i).Get(); + } + cc->Outputs().Index(0).Add(new int(sum), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(IntAdderCalculator); + +template +class TypedSinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } +}; +typedef TypedSinkCalculator StringSinkCalculator; +typedef TypedSinkCalculator IntSinkCalculator; +REGISTER_CALCULATOR(StringSinkCalculator); +REGISTER_CALCULATOR(IntSinkCalculator); + +// A Calculator that passes an input packet through if it contains an even +// integer. +class EvenIntFilterCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + int value = cc->Inputs().Index(0).Get(); + if (value % 2 == 0) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + } else { + cc->Outputs().Index(0).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(EvenIntFilterCalculator); + +// A Calculator that passes packets through or not, depending on a second +// input. The first input stream's packets are only propagated if the second +// input stream carries the value true. +class ValveCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Inputs().Index(1).Set(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + if (cc->Inputs().Index(1).Get()) { + cc->GetCounter("PassThrough")->Increment(); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + } else { + cc->GetCounter("Block")->Increment(); + // The next timestamp bound is the minimum timestamp that the next packet + // can have, so, if we want to inform the downstream that no packet at + // InputTimestamp() is coming, we need to set it to the next value. + // We could also just call SetOffset(TimestampDiff(0)) in Open, and then + // we would not have to call this manually. + cc->Outputs().Index(0).SetNextTimestampBound( + cc->InputTimestamp().NextAllowedInStream()); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(ValveCalculator); + +// A Calculator that simply passes its input Packets and header through, +// but shifts the timestamp. 
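+// The shift is read from an input side packet in Open() and also registered
+// via SetOffset(), so downstream timestamp bounds follow the shifted
+// timestamps.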
+class TimeShiftCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + cc->InputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + // Input: arbitrary Packets. + // Output: copy of the input. + cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); + shift_ = cc->InputSidePackets().Index(0).Get(); + cc->SetOffset(shift_); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->GetCounter("PassThrough")->Increment(); + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Index(0).Value().At(cc->InputTimestamp() + shift_)); + return ::mediapipe::OkStatus(); + } + + private: + TimestampDiff shift_; +}; +REGISTER_CALCULATOR(TimeShiftCalculator); + +// A source calculator that alternates between outputting an integer (0, 1, 2, +// ..., 100) and setting the next timestamp bound. The timestamps of the output +// packets and next timestamp bounds are 0, 10, 20, 30, ... +// +// T=0 Output 0 +// T=10 Set timestamp bound +// T=20 Output 1 +// T=30 Set timestamp bound +// ... +// T=2000 Output 100 +class OutputAndBoundSourceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + counter_ = 0; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + Timestamp timestamp(counter_); + if (counter_ % 20 == 0) { + cc->Outputs().Index(0).AddPacket( + MakePacket(counter_ / 20).At(timestamp)); + } else { + cc->Outputs().Index(0).SetNextTimestampBound(timestamp); + } + if (counter_ == 2000) { + return tool::StatusStop(); + } + counter_ += 10; + return ::mediapipe::OkStatus(); + } + + private: + int counter_; +}; +REGISTER_CALCULATOR(OutputAndBoundSourceCalculator); + +// A calculator that outputs an initial packet of value 0 at time 0 in the +// Open() method, and then delays each input packet by 20 time units in the +// Process() method. The input stream and output stream have the integer type. +class Delay20Calculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(20)); + cc->Outputs().Index(0).AddPacket(MakePacket(0).At(Timestamp(0))); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + const Packet& packet = cc->Inputs().Index(0).Value(); + Timestamp timestamp = packet.Timestamp() + 20; + cc->Outputs().Index(0).AddPacket(packet.At(timestamp)); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(Delay20Calculator); + class CustomBoundCalculator : public CalculatorBase { public: static ::mediapipe::Status GetContract(CalculatorContract* cc) { @@ -45,8 +292,280 @@ class CustomBoundCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(CustomBoundCalculator); +// Test that SetNextTimestampBound propagates. 
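+// The test graph gates 'in' through a ValveCalculator, passes the gated
+// stream through a PassThroughCalculator and a TimeShiftCalculator, and then
+// merges the shifted stream with the original 'in' stream in a
+// MergeCalculator. When the gate is closed, the ValveCalculator emits no
+// packet but advances the next timestamp bound, and the MergeCalculator must
+// still be able to fire.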
+TEST(CalculatorGraph, SetNextTimestampBoundPropagation) { + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in' + input_stream: 'gate' + node { + calculator: 'ValveCalculator' + input_stream: 'in' + input_stream: 'gate' + output_stream: 'gated' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'gated' + output_stream: 'passed' + } + node { + calculator: 'TimeShiftCalculator' + input_stream: 'passed' + output_stream: 'shifted' + input_side_packet: 'shift' + } + node { + calculator: 'MergeCalculator' + input_stream: 'in' + input_stream: 'shifted' + output_stream: 'merged' + } + node { + name: 'merged_output' + calculator: 'PassThroughCalculator' + input_stream: 'merged' + output_stream: 'out' + } + )"); + + Timestamp timestamp = Timestamp(0); + auto send_inputs = [&graph, ×tamp](int input, bool pass) { + ++timestamp; + MP_EXPECT_OK(graph.AddPacketToInputStream( + "in", MakePacket(input).At(timestamp))); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "gate", MakePacket(pass).At(timestamp))); + }; + + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({{"shift", MakePacket(0)}})); + + auto pass_counter = + graph.GetCounterFactory()->GetCounter("ValveCalculator-PassThrough"); + auto block_counter = + graph.GetCounterFactory()->GetCounter("ValveCalculator-Block"); + auto merged_counter = + graph.GetCounterFactory()->GetCounter("merged_output-PassThrough"); + + send_inputs(1, true); + send_inputs(2, true); + send_inputs(3, false); + send_inputs(4, false); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Verify that MergeCalculator was able to run even when the gated branch + // was blocked. + EXPECT_EQ(2, pass_counter->Get()); + EXPECT_EQ(2, block_counter->Get()); + EXPECT_EQ(4, merged_counter->Get()); + + send_inputs(5, true); + send_inputs(6, false); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_EQ(3, pass_counter->Get()); + EXPECT_EQ(3, block_counter->Get()); + EXPECT_EQ(6, merged_counter->Get()); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + // Now test with time shift + MP_ASSERT_OK(graph.StartRun({{"shift", MakePacket(-1)}})); + + send_inputs(7, true); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // The merger should have run only once now, at timestamp 6, with inputs + // . If we do not respect the offset and unblock the merger for + // timestamp 7 too, then it will have run twice, with 6: and + // 7: <7, null>. + EXPECT_EQ(4, pass_counter->Get()); + EXPECT_EQ(3, block_counter->Get()); + EXPECT_EQ(7, merged_counter->Get()); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(4, pass_counter->Get()); + EXPECT_EQ(3, block_counter->Get()); + EXPECT_EQ(8, merged_counter->Get()); +} + +// Both input streams of the calculator node have the same next timestamp +// bound. One input stream has a packet at that timestamp. The other input +// stream is empty. We should not run the Process() method of the node in this +// case. 
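+// The empty stream's bound is advanced by the EvenIntFilterCalculator, which
+// calls SetNextTimestampBound() whenever it filters out an odd integer.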
+TEST(CalculatorGraph, NotAllInputPacketsAtNextTimestampBoundAvailable) { + // + // in0_unfiltered in1_to_be_filtered + // | | + // | V + // | +-----------------------+ + // | |EvenIntFilterCalculator| + // | +-----------------------+ + // | | + // \ / + // \ / in1_filtered + // \ / + // | | + // V V + // +------------------+ + // |IntAdderCalculator| + // +------------------+ + // | + // V + // out + // + CalculatorGraph graph; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'in0_unfiltered' + input_stream: 'in1_to_be_filtered' + node { + calculator: 'EvenIntFilterCalculator' + input_stream: 'in1_to_be_filtered' + output_stream: 'in1_filtered' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'in0_unfiltered' + input_stream: 'in1_filtered' + output_stream: 'out' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("out", &config, &packet_dump); + + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + + Timestamp timestamp = Timestamp(0); + + // We send an integer with timestamp 1 to the in0_unfiltered input stream of + // the IntAdderCalculator. We then send an even integer with timestamp 1 to + // the EvenIntFilterCalculator. This packet will go through and + // the IntAdderCalculator will run. The next timestamp bounds of both the + // input streams of the IntAdderCalculator will become 2. + + ++timestamp; // Timestamp 1. + MP_EXPECT_OK(graph.AddPacketToInputStream("in0_unfiltered", + MakePacket(1).At(timestamp))); + MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered", + MakePacket(2).At(timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(3, packet_dump[0].Get()); + + // We send an odd integer with timestamp 2 to the EvenIntFilterCalculator. + // This packet will be filtered out and the next timestamp bound of the + // in1_filtered input stream of the IntAdderCalculator will become 3. + + ++timestamp; // Timestamp 2. + MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered", + MakePacket(3).At(timestamp))); + + // We send an integer with timestamp 3 to the in0_unfiltered input stream of + // the IntAdderCalculator. MediaPipe should propagate the next timestamp bound + // across the IntAdderCalculator but should not run its Process() method. + + ++timestamp; // Timestamp 3. + MP_EXPECT_OK(graph.AddPacketToInputStream("in0_unfiltered", + MakePacket(3).At(timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + + // We send an even integer with timestamp 3 to the IntAdderCalculator. This + // packet will go through and the IntAdderCalculator will run. 
+ + MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered", + MakePacket(4).At(timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(2, packet_dump.size()); + EXPECT_EQ(7, packet_dump[1].Get()); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_EQ(2, packet_dump.size()); +} + +TEST(CalculatorGraph, PropagateBoundLoop) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'OutputAndBoundSourceCalculator' + output_stream: 'integers' + } + node { + calculator: 'IntAdderCalculator' + input_stream: 'integers' + input_stream: 'old_sum' + input_stream_info: { + tag_index: ':1' # 'old_sum' + back_edge: true + } + output_stream: 'sum' + input_stream_handler { + input_stream_handler: 'EarlyCloseInputStreamHandler' + } + } + node { + calculator: 'Delay20Calculator' + input_stream: 'sum' + output_stream: 'old_sum' + } + )"); + std::vector packet_dump; + tool::AddVectorSink("sum", &config, &packet_dump); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.Run()); + ASSERT_EQ(101, packet_dump.size()); + int sum = 0; + for (int i = 0; i < 101; ++i) { + sum += i; + EXPECT_EQ(sum, packet_dump[i].Get()); + EXPECT_EQ(Timestamp(i * 20), packet_dump[i].Timestamp()); + } +} + +TEST(CalculatorGraph, CheckBatchProcessingBoundPropagation) { + // The timestamp bound sent by OutputAndBoundSourceCalculator shouldn't be + // directly propagated to the output stream when PassThroughCalculator has + // anything in its default calculator context for batch processing. Otherwise, + // the sink calculator's input stream should report packet timestamp + // mismatches. + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'OutputAndBoundSourceCalculator' + output_stream: 'integers' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'integers' + output_stream: 'output' + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + options: { + [mediapipe.DefaultInputStreamHandlerOptions.ext]: { + batch_size: 10 + } + } + } + } + node { calculator: 'IntSinkCalculator' input_stream: 'output' } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.Run()); +} + // Shows that ImmediateInputStreamHandler allows bounds propagation. -TEST(CalculatorGraphBounds, ImmediateHandlerBounds) { +TEST(CalculatorGraphBoundsTest, ImmediateHandlerBounds) { // CustomBoundCalculator produces only timestamp bounds. // The first PassThroughCalculator propagates bounds using SetOffset(0). // The second PassthroughCalculator delivers an output packet whenever the @@ -101,5 +620,261 @@ TEST(CalculatorGraphBounds, ImmediateHandlerBounds) { EXPECT_EQ(output_packets.size(), 4); } +// A Calculator that only sets timestamp bound by SetOffset(). +class OffsetBoundCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(0); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(OffsetBoundCalculator); + +// A Calculator that produces a packet for each call to Process. 
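+// It is used below to make timestamp-bound updates observable: an output
+// packet appears only when the framework actually invokes Process().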
+class BoundToPacketCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->Outputs().Index(0).AddPacket(Adopt(new int(33))); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(BoundToPacketCalculator); + +// Verifies that SetOffset still propagates when Process is called and +// produces no output packets. +TEST(CalculatorGraphBoundsTest, OffsetBoundPropagation) { + // OffsetBoundCalculator produces only timestamp bounds. + // The PassthroughCalculator delivers an output packet whenever the + // OffsetBoundCalculator delivers a timestamp bound. + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'input' + node { + calculator: 'OffsetBoundCalculator' + input_stream: 'input' + output_stream: 'bounds' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'bounds' + input_stream: 'input' + output_stream: 'bounds_output' + output_stream: 'output' + } + )"); + CalculatorGraph graph; + std::vector output_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { + output_packets.push_back(p); + return ::mediapipe::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Add four packets into the graph. + constexpr int kNumInputs = 4; + for (int i = 0; i < kNumInputs; ++i) { + Packet p = MakePacket(33).At(Timestamp(i)); + MP_ASSERT_OK(graph.AddPacketToInputStream("input", p)); + } + + // Four packets arrive at the output only if timestamp bounds are propagated. + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(output_packets.size(), kNumInputs); + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Shows that bounds changes alone do not invoke Process. +// Note: Bounds changes alone will invoke Process eventually +// when SetOffset is cleared, see: go/mediapipe-realtime-graph. +TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) { + // OffsetBoundCalculator produces only timestamp bounds. + // The BoundToPacketCalculator delivers an output packet whenever the + // OffsetBoundCalculator delivers a timestamp bound. + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'input' + node { + calculator: 'OffsetBoundCalculator' + input_stream: 'input' + output_stream: 'bounds' + } + node { + calculator: 'BoundToPacketCalculator' + input_stream: 'bounds' + output_stream: 'output' + } + )"); + CalculatorGraph graph; + std::vector output_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { + output_packets.push_back(p); + return ::mediapipe::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Add four packets into the graph. + constexpr int kNumInputs = 4; + for (int i = 0; i < kNumInputs; ++i) { + Packet p = MakePacket(33).At(Timestamp(i)); + MP_ASSERT_OK(graph.AddPacketToInputStream("input", p)); + } + + // No packets arrive, because updated timestamp bounds do not invoke + // BoundToPacketCalculator::Process. 
+ MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(output_packets.size(), 0); + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Shows that when fixed-size-input-stream-hanlder drops packets, +// no timetamp bounds are announced. +TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) { + // LambdaCalculator with FixedSizeInputStreamHandler will drop packets + // while it is busy. Timetamps for the dropped packets are only relevant + // when SetOffset is active on the LambdaCalculator. + // The PassthroughCalculator delivers an output packet whenever the + // LambdaCalculator delivers a timestamp bound. + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'input' + input_side_packet: 'open_function' + input_side_packet: 'process_function' + node { + calculator: 'LambdaCalculator' + input_stream: 'input' + output_stream: 'thinned' + input_side_packet: 'OPEN:open_fn' + input_side_packet: 'PROCESS:process_fn' + input_stream_handler { + input_stream_handler: "FixedSizeInputStreamHandler" + } + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'thinned' + input_stream: 'input' + output_stream: 'thinned_output' + output_stream: 'output' + } + )"); + CalculatorGraph graph; + + // The task_semaphore counts the number of running tasks. + constexpr int kTaskSupply = 10; + AtomicSemaphore task_semaphore(/*supply=*/kTaskSupply); + + // This executor invokes a callback at the start and finish of each task. + auto executor = std::make_shared( + 4, /*start_callback=*/[&]() { task_semaphore.Acquire(1); }, + /*finish_callback=*/[&]() { task_semaphore.Release(1); }); + MP_ASSERT_OK(graph.SetExecutor(/*name=*/"", executor)); + + // Monitor output from the graph. + MP_ASSERT_OK(graph.Initialize(config)); + std::vector outputs; + MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { + outputs.push_back(p); + return ::mediapipe::OkStatus(); + })); + std::vector thinned_outputs; + MP_ASSERT_OK( + graph.ObserveOutputStream("thinned_output", [&](const Packet& p) { + thinned_outputs.push_back(p); + return ::mediapipe::OkStatus(); + })); + + // The enter_semaphore is used to wait for LambdaCalculator::Process. + // The exit_semaphore blocks and unblocks LambdaCalculator::Process. + AtomicSemaphore enter_semaphore(0); + AtomicSemaphore exit_semaphore(0); + CalculatorContextFunction open_fn = [&](CalculatorContext* cc) { + cc->SetOffset(0); + return ::mediapipe::OkStatus(); + }; + CalculatorContextFunction process_fn = [&](CalculatorContext* cc) { + enter_semaphore.Release(1); + exit_semaphore.Acquire(1); + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + }; + MP_ASSERT_OK(graph.StartRun({ + {"open_fn", Adopt(new auto(open_fn))}, + {"process_fn", Adopt(new auto(process_fn))}, + })); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Add four packets into the graph. + constexpr int kNumInputs = 4; + for (int i = 0; i < kNumInputs; ++i) { + Packet p = MakePacket(33).At(Timestamp(i)); + MP_ASSERT_OK(graph.AddPacketToInputStream("input", p)); + } + + // Wait until only the LambdaCalculator is running, + // by wating until the task_semaphore has only one token occupied. + // At this point 2 packets were dropped by the FixedSizeInputStreamHandler. + task_semaphore.Acquire(kTaskSupply - 1); + task_semaphore.Release(kTaskSupply - 1); + + // No timestamp bounds and no packets are emitted yet. 
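// The task_semaphore, enter_semaphore, and exit_semaphore used above are
// instances of the AtomicSemaphore test helper defined elsewhere in this
// test. A minimal counting-semaphore sketch with the same Acquire/Release
// surface, assuming an absl::Mutex-based implementation (illustration only):
class CountingSemaphoreSketch {
 public:
  explicit CountingSemaphoreSketch(int64 supply) : supply_(supply) {}

  void Acquire(int64 amount) {
    absl::MutexLock lock(&mutex_);
    auto enough = [this, amount]() { return supply_ >= amount; };
    mutex_.Await(absl::Condition(&enough));
    supply_ -= amount;
  }

  void Release(int64 amount) {
    absl::MutexLock lock(&mutex_);
    supply_ += amount;
  }

 private:
  absl::Mutex mutex_;
  int64 supply_;
};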
+ EXPECT_EQ(outputs.size(), 0); + EXPECT_EQ(thinned_outputs.size(), 0); + + // Allow the first LambdaCalculator::Process call to complete. + // Wait for the second LambdaCalculator::Process call to begin. + // Wait until only the LambdaCalculator is running. + enter_semaphore.Acquire(1); + exit_semaphore.Release(1); + enter_semaphore.Acquire(1); + task_semaphore.Acquire(kTaskSupply - 1); + task_semaphore.Release(kTaskSupply - 1); + + // Only one timestamp bound and one packet are emitted. + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(thinned_outputs.size(), 1); + + // Allow the second LambdaCalculator::Process call to complete. + exit_semaphore.Release(1); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Packets 1 and 2 were dropped by the FixedSizeInputStreamHandler. + EXPECT_EQ(thinned_outputs.size(), 2); + EXPECT_EQ(thinned_outputs[0].Timestamp(), Timestamp(0)); + EXPECT_EQ(thinned_outputs[1].Timestamp(), Timestamp(kNumInputs - 1)); + EXPECT_EQ(outputs.size(), kNumInputs); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/calculator_graph_side_packet_test.cc b/mediapipe/framework/calculator_graph_side_packet_test.cc new file mode 100644 index 000000000..166826ff1 --- /dev/null +++ b/mediapipe/framework/calculator_graph_side_packet_test.cc @@ -0,0 +1,747 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +// Takes an input stream packet and passes it (with timestamp removed) as an +// output side packet. +class OutputSidePacketInProcessCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set( + cc->Inputs().Index(0).Value().At(Timestamp::Unset())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); + +// Takes an input stream packet and counts the number of the packets it +// receives. Outputs the total number of packets as a side packet in Close. 
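// OutputSidePacketInProcessCalculator above strips the stream timestamp with
// At(Timestamp::Unset()) because output side packets may not carry one, and a
// side packet may be set at most once per run; later tests in this file
// exercise both error paths. A hypothetical defensive variant that tolerates
// repeated input packets by emitting only the first one it sees (illustration
// only, not registered):
class FirstPacketToSidePacketCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    if (!emitted_) {
      cc->OutputSidePackets().Index(0).Set(
          cc->Inputs().Index(0).Value().At(Timestamp::Unset()));
      emitted_ = true;
    }
    return ::mediapipe::OkStatus();
  }

 private:
  bool emitted_ = false;
};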
+class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->OutputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + ++count_; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set( + MakePacket(count_).At(Timestamp::Unset())); + return ::mediapipe::OkStatus(); + } + + int count_ = 0; +}; +REGISTER_CALCULATOR(CountAndOutputSummarySidePacketInCloseCalculator); + +// Takes an input stream packet and passes it (with timestamp intact) as an +// output side packet. This triggers an error in the graph. +class OutputSidePacketWithTimestampCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set(cc->Inputs().Index(0).Value()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(OutputSidePacketWithTimestampCalculator); + +// Generates an output side packet containing the integer 1. +class IntegerOutputSidePacketCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->OutputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set(MakePacket(1)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + LOG(FATAL) << "Not reached."; + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(IntegerOutputSidePacketCalculator); + +// Generates an output side packet containing the sum of the two integer input +// side packets. +class SidePacketAdderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Index(0).Set(); + cc->InputSidePackets().Index(1).Set(); + cc->OutputSidePackets().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->OutputSidePackets().Index(0).Set( + MakePacket(cc->InputSidePackets().Index(1).Get() + + cc->InputSidePackets().Index(0).Get())); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + LOG(FATAL) << "Not reached."; + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(SidePacketAdderCalculator); + +// Produces an output packet with the PostStream timestamp containing the +// input side packet. 
+class SidePacketToStreamPacketCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->InputSidePackets().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->Outputs().Index(0).AddPacket( + cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); + cc->Outputs().Index(0).Close(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + return ::mediapipe::tool::StatusStop(); + } +}; +REGISTER_CALCULATOR(SidePacketToStreamPacketCalculator); + +// Packet generator for an arbitrary unit64 packet. +class Uint64PacketGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, + PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { + output_side_packets->Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5)); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(Uint64PacketGenerator); + +TEST(CalculatorGraph, OutputSidePacketInProcess) { + const int64 offset = 100; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "offset" + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "offset" + output_side_packet: "offset" + } + node { + calculator: "SidePacketToStreamPacketCalculator" + output_stream: "output" + input_side_packet: "offset" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + // Run the graph twice. + for (int run = 0; run < 2; ++run) { + output_packets.clear(); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(0)))); + MP_ASSERT_OK(graph.CloseInputStream("offset")); + MP_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(offset, output_packets[0].Get().Value()); + } +} + +// A PacketGenerator that simply passes its input Packets through +// unchanged. The inputs may be specified by tag or index. The outputs +// must match the inputs exactly. Any options may be specified and will +// also be ignored. 
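// The Uint64PacketGenerator defined earlier in this file packs two 32-bit
// values into a single uint64: (15LL << 32) | 5 stores 15 in the high word
// and 5 in the low word. A later test hands such a value to
// IntSplitterPacketGenerator, which conceptually recovers the two halves as
// shown below (illustrative arithmetic only):
uint64 packed = (static_cast<uint64>(15) << 32) | 5;
uint32 high_word = static_cast<uint32>(packed >> 32);         // == 15
uint32 low_word = static_cast<uint32>(packed & 0xFFFFFFFFu);  // == 5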
+class PassThroughGenerator : public PacketGenerator { + public: + static ::mediapipe::Status FillExpectations( + const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs, + PacketTypeSet* outputs) { + if (!inputs->TagMap()->SameAs(*outputs->TagMap())) { + return ::mediapipe::InvalidArgumentError( + "Input and outputs to PassThroughGenerator must use the same tags " + "and indexes."); + } + for (CollectionItemId id = inputs->BeginId(); id < inputs->EndId(); ++id) { + inputs->Get(id).SetAny(); + outputs->Get(id).SetSameAs(&inputs->Get(id)); + } + return ::mediapipe::OkStatus(); + } + + static ::mediapipe::Status Generate( + const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, PacketSet* output_side_packets) { + for (CollectionItemId id = input_side_packets.BeginId(); + id < input_side_packets.EndId(); ++id) { + output_side_packets->Get(id) = input_side_packets.Get(id); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_PACKET_GENERATOR(PassThroughGenerator); + +TEST(CalculatorGraph, SharePacketGeneratorGraph) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count1' + input_side_packet: 'MAX_COUNT:max_count1' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count2' + input_side_packet: 'MAX_COUNT:max_count2' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count3' + input_side_packet: 'MAX_COUNT:max_count3' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count4' + input_side_packet: 'MAX_COUNT:max_count4' + } + node { + calculator: 'PassThroughCalculator' + input_side_packet: 'MAX_COUNT:max_count5' + output_side_packet: 'MAX_COUNT:max_count6' + } + node { + calculator: 'CountingSourceCalculator' + output_stream: 'count5' + input_side_packet: 'MAX_COUNT:max_count6' + } + packet_generator { + packet_generator: 'PassThroughGenerator' + input_side_packet: 'max_count1' + output_side_packet: 'max_count2' + } + packet_generator { + packet_generator: 'PassThroughGenerator' + input_side_packet: 'max_count4' + output_side_packet: 'max_count5' + } + )"); + + // At this point config is a standard config which specifies both + // calculators and packet_factories/packet_genators. The following + // code is an example of reusing side packets across a number of + // CalculatorGraphs. It is particularly informative to note how each + // side packet is created. + // + // max_count1 is set for all graphs by a PacketFactory in the config. + // The side packet is created by generator_graph.InitializeGraph(). + // + // max_count2 is set for all graphs by a PacketGenerator in the config. + // The side packet is created by generator_graph.InitializeGraph() + // because max_count1 is available at that time. + // + // max_count3 is set for all graphs by directly being specified as an + // argument to generator_graph.InitializeGraph(). + // + // max_count4 is set per graph because it is directly specified as an + // argument to generator_graph.ProcessGraph(). + // + // max_count5 is set per graph by a PacketGenerator which is run when + // generator_graph.ProcessGraph() is run (because max_count4 isn't + // available until then). + + // Before anything else, split the graph config into two parts, one + // with the PacketFactory and PacketGenerator config and the other + // with the Calculator config. 
+ CalculatorGraphConfig calculator_config = config; + calculator_config.clear_packet_factory(); + calculator_config.clear_packet_generator(); + CalculatorGraphConfig generator_config = config; + generator_config.clear_node(); + + // Next, create a ValidatedGraphConfig for both configs. + ValidatedGraphConfig validated_calculator_config; + MP_ASSERT_OK(validated_calculator_config.Initialize(calculator_config)); + ValidatedGraphConfig validated_generator_config; + MP_ASSERT_OK(validated_generator_config.Initialize(generator_config)); + + // Create a PacketGeneratorGraph. Side packets max_count1, max_count2, + // and max_count3 are created upon initialization. + // Note that validated_generator_config must outlive generator_graph. + PacketGeneratorGraph generator_graph; + MP_ASSERT_OK( + generator_graph.Initialize(&validated_generator_config, nullptr, + {{"max_count1", MakePacket(10)}, + {"max_count3", MakePacket(20)}})); + ASSERT_THAT(generator_graph.BasePackets(), + testing::ElementsAre(testing::Key("max_count1"), + testing::Key("max_count2"), + testing::Key("max_count3"))); + + // Create a bunch of graphs. + std::vector> graphs; + for (int i = 0; i < 100; ++i) { + graphs.emplace_back(absl::make_unique()); + // Do not pass extra side packets here. + // Note that validated_calculator_config must outlive the graph. + MP_ASSERT_OK(graphs.back()->Initialize(calculator_config, {})); + } + // Run a bunch of graphs, reusing side packets max_count1, max_count2, + // and max_count3. The side packet max_count4 is added per run, + // and triggers the execution of a packet generator which generates + // max_count5. + for (int i = 0; i < 100; ++i) { + std::map all_side_packets; + // Creates max_count4 and max_count5. + MP_ASSERT_OK(generator_graph.RunGraphSetup( + {{"max_count4", MakePacket(30 + i)}}, &all_side_packets)); + ASSERT_THAT(all_side_packets, + testing::ElementsAre( + testing::Key("max_count1"), testing::Key("max_count2"), + testing::Key("max_count3"), testing::Key("max_count4"), + testing::Key("max_count5"))); + // Pass all the side packets prepared by generator_graph here. + MP_ASSERT_OK(graphs[i]->Run(all_side_packets)); + // TODO Verify the actual output. + } + + // Destroy all the graphs. + graphs.clear(); +} + +TEST(CalculatorGraph, OutputSidePacketAlreadySet) { + const int64 offset = 100; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "offset" + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "offset" + output_side_packet: "offset" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + // Send two input packets to cause OutputSidePacketInProcessCalculator to + // set the output side packet twice. 
+ MP_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(1)))); + MP_ASSERT_OK(graph.CloseInputStream("offset")); + + ::mediapipe::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kAlreadyExists); + EXPECT_THAT(status.message(), testing::HasSubstr("was already set.")); +} + +TEST(CalculatorGraph, OutputSidePacketWithTimestamp) { + const int64 offset = 100; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "offset" + node { + calculator: "OutputSidePacketWithTimestampCalculator" + input_stream: "offset" + output_side_packet: "offset" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + // The OutputSidePacketWithTimestampCalculator neglects to clear the + // timestamp in the input packet when it copies the input packet to the + // output side packet. The timestamp value should appear in the error + // message. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "offset", MakePacket(offset).At(Timestamp(237)))); + MP_ASSERT_OK(graph.CloseInputStream("offset")); + ::mediapipe::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), testing::HasSubstr("has a timestamp 237.")); +} + +TEST(CalculatorGraph, OutputSidePacketConsumedBySourceNode) { + const int max_count = 10; + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "max_count" + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "max_count" + output_side_packet: "max_count" + } + node { + calculator: "CountingSourceCalculator" + output_stream: "count" + input_side_packet: "MAX_COUNT:max_count" + } + node { + calculator: "PassThroughCalculator" + input_stream: "count" + output_stream: "output" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + // Wait until the graph is idle so that + // Scheduler::TryToScheduleNextSourceLayer() gets called. + // Scheduler::TryToScheduleNextSourceLayer() should not activate source + // nodes that haven't been opened. We can't call graph.WaitUntilIdle() + // because the graph has a source node. + absl::SleepFor(absl::Milliseconds(10)); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "max_count", MakePacket(max_count).At(Timestamp(0)))); + MP_ASSERT_OK(graph.CloseInputStream("max_count")); + MP_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(max_count, output_packets.size()); + for (int i = 0; i < output_packets.size(); ++i) { + EXPECT_EQ(i, output_packets[i].Get()); + EXPECT_EQ(Timestamp(i), output_packets[i].Timestamp()); + } +} + +// Returns the first packet of the input stream. 
+class FirstPacketFilterCalculator : public CalculatorBase { + public: + FirstPacketFilterCalculator() {} + ~FirstPacketFilterCalculator() override {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (!seen_first_packet_) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + cc->Outputs().Index(0).Close(); + seen_first_packet_ = true; + } + return ::mediapipe::OkStatus(); + } + + private: + bool seen_first_packet_ = false; +}; +REGISTER_CALCULATOR(FirstPacketFilterCalculator); + +TEST(CalculatorGraph, SourceLayerInversion) { + // There are three CountingSourceCalculators, indexed 0, 1, and 2. Each of + // them outputs 10 packets. + // + // CountingSourceCalculator 0 should output 0, 1, 2, 3, ..., 9. + // CountingSourceCalculator 1 should output 100, 101, 102, 103, ..., 109. + // CountingSourceCalculator 2 should output 0, 100, 200, 300, ..., 900. + // However, there is a source layer inversion. + // CountingSourceCalculator 0 is in source layer 0. + // CountingSourceCalculator 1 is in source layer 1. + // CountingSourceCalculator 2 is in source layer 0, but consumes an output + // side packet generated by a downstream calculator of + // CountingSourceCalculator 1. + // + // This graph will deadlock when CountingSourceCalculator 0 runs to + // completion and CountingSourceCalculator 1 cannot be activated because + // CountingSourceCalculator 2 cannot be opened. + + const int max_count = 10; + const int initial_value1 = 100; + // Set num_threads to 1 to force sequential execution for deterministic + // outputs. + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + num_threads: 1 + node { + calculator: "CountingSourceCalculator" + output_stream: "count0" + input_side_packet: "MAX_COUNT:max_count" + source_layer: 0 + } + + node { + calculator: "CountingSourceCalculator" + output_stream: "count1" + input_side_packet: "MAX_COUNT:max_count" + input_side_packet: "INITIAL_VALUE:initial_value1" + source_layer: 1 + } + node { + calculator: "FirstPacketFilterCalculator" + input_stream: "count1" + output_stream: "first_count1" + } + node { + calculator: "OutputSidePacketInProcessCalculator" + input_stream: "first_count1" + output_side_packet: "increment2" + } + + node { + calculator: "CountingSourceCalculator" + output_stream: "count2" + input_side_packet: "MAX_COUNT:max_count" + input_side_packet: "INCREMENT:increment2" + source_layer: 0 + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize( + config, {{"max_count", MakePacket(max_count)}, + {"initial_value1", MakePacket(initial_value1)}})); + ::mediapipe::Status status = graph.Run(); + EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnknown); + EXPECT_THAT(status.message(), testing::HasSubstr("deadlock")); +} + +// Tests a graph of packet-generator-like calculators, which have no input +// streams and no output streams. 
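// In SourceLayerInversion above, the deadlock arises because a source in
// layer 0 depends (via increment2) on data produced downstream of a source in
// layer 1, while lower-numbered source layers are scheduled first. One
// plausible repair for a real graph, assuming source layers are activated
// strictly in ascending order, is to move the dependent source into a layer
// later than the one that feeds it (config sketch only):
//
//   node {
//     calculator: "CountingSourceCalculator"
//     output_stream: "count2"
//     input_side_packet: "MAX_COUNT:max_count"
//     input_side_packet: "INCREMENT:increment2"
//     source_layer: 2
//   }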
+TEST(CalculatorGraph, PacketGeneratorLikeCalculators) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "IntegerOutputSidePacketCalculator" + output_side_packet: "one" + } + node { + calculator: "IntegerOutputSidePacketCalculator" + output_side_packet: "another_one" + } + node { + calculator: "SidePacketAdderCalculator" + input_side_packet: "one" + input_side_packet: "another_one" + output_side_packet: "two" + } + node { + calculator: "IntegerOutputSidePacketCalculator" + output_side_packet: "yet_another_one" + } + node { + calculator: "SidePacketAdderCalculator" + input_side_packet: "two" + input_side_packet: "yet_another_one" + output_side_packet: "three" + } + node { + calculator: "SidePacketToStreamPacketCalculator" + input_side_packet: "three" + output_stream: "output" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + MP_ASSERT_OK(graph.Run()); + ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(3, output_packets[0].Get()); + EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp()); +} + +TEST(CalculatorGraph, OutputSummarySidePacketInClose) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_packets" + node { + calculator: "CountAndOutputSummarySidePacketInCloseCalculator" + input_stream: "input_packets" + output_side_packet: "num_of_packets" + } + node { + calculator: "SidePacketToStreamPacketCalculator" + input_side_packet: "num_of_packets" + output_stream: "output" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", [&output_packets](const Packet& packet) { + output_packets.push_back(packet); + return ::mediapipe::OkStatus(); + })); + + // Run the graph twice. + int max_count = 100; + for (int run = 0; run < 1; ++run) { + output_packets.clear(); + MP_ASSERT_OK(graph.StartRun({})); + for (int i = 0; i < max_count; ++i) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_packets", MakePacket(i).At(Timestamp(i)))); + } + MP_ASSERT_OK(graph.CloseInputStream("input_packets")); + MP_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_EQ(1, output_packets.size()); + EXPECT_EQ(max_count, output_packets[0].Get()); + EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp()); + } +} + +TEST(CalculatorGraph, GetOutputSidePacket) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_packets" + node { + calculator: "CountAndOutputSummarySidePacketInCloseCalculator" + input_stream: "input_packets" + output_side_packet: "num_of_packets" + } + packet_generator { + packet_generator: "Uint64PacketGenerator" + output_side_packet: "output_uint64" + } + packet_generator { + packet_generator: "IntSplitterPacketGenerator" + input_side_packet: "input_uint64" + output_side_packet: "output_uint32_pair" + } + )"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + // Check a packet generated by the PacketGenerator, which is available after + // graph initialization, can be fetched before graph starts. 
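// Sketch of how application code might branch on a GetOutputSidePacket result
// at different points in the graph lifecycle; the test below exercises each
// of these cases (illustration only):
::mediapipe::StatusOr<Packet> side = graph.GetOutputSidePacket("num_of_packets");
if (side.ok()) {
  // A base packet, or a per-run packet whose producing run has finished.
} else if (side.status().code() == ::mediapipe::StatusCode::kUnavailable) {
  // The side packet exists in the config but has not been produced yet.
} else if (side.status().code() == ::mediapipe::StatusCode::kNotFound) {
  // No output side packet with this name exists in the graph.
}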
+ ::mediapipe::StatusOr status_or_packet = + graph.GetOutputSidePacket("output_uint64"); + MP_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + // IntSplitterPacketGenerator is missing its input side packet and we + // won't be able to get its output side packet now. + status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); + EXPECT_EQ(::mediapipe::StatusCode::kUnavailable, + status_or_packet.status().code()); + // Run the graph twice. + int max_count = 100; + std::map extra_side_packets; + extra_side_packets.insert({"input_uint64", MakePacket(1123)}); + for (int run = 0; run < 1; ++run) { + MP_ASSERT_OK(graph.StartRun(extra_side_packets)); + status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); + MP_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + for (int i = 0; i < max_count; ++i) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_packets", MakePacket(i).At(Timestamp(i)))); + } + MP_ASSERT_OK(graph.CloseInputStream("input_packets")); + + // Should return NOT_FOUND for invalid side packets. + status_or_packet = graph.GetOutputSidePacket("unknown"); + EXPECT_FALSE(status_or_packet.ok()); + EXPECT_EQ(::mediapipe::StatusCode::kNotFound, + status_or_packet.status().code()); + // Should return UNAVAILABLE before graph is done for valid non-base + // packets. + status_or_packet = graph.GetOutputSidePacket("num_of_packets"); + EXPECT_FALSE(status_or_packet.ok()); + EXPECT_EQ(::mediapipe::StatusCode::kUnavailable, + status_or_packet.status().code()); + // Should stil return a base even before graph is done. + status_or_packet = graph.GetOutputSidePacket("output_uint64"); + MP_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + + MP_ASSERT_OK(graph.WaitUntilDone()); + + // Check packets are available after graph is done. + status_or_packet = graph.GetOutputSidePacket("num_of_packets"); + MP_ASSERT_OK(status_or_packet); + EXPECT_EQ(max_count, status_or_packet.ValueOrDie().Get()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + // Should still return a base packet after graph is done. + status_or_packet = graph.GetOutputSidePacket("output_uint64"); + MP_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + // Should still return a non-base packet after graph is done. + status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); + MP_ASSERT_OK(status_or_packet); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 3f9505250..f45fd5175 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -32,6 +32,7 @@ #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/clock.h" +#include "absl/time/time.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/counter_factory.h" @@ -355,97 +356,6 @@ class IntToFloatCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(IntToFloatCalculator); -// A Calculator that passes an input packet through if it contains an even -// integer. 
-class EvenIntFilterCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).Set(); - cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - int value = cc->Inputs().Index(0).Get(); - if (value % 2 == 0) { - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - } else { - cc->Outputs().Index(0).SetNextTimestampBound( - cc->InputTimestamp().NextAllowedInStream()); - } - return ::mediapipe::OkStatus(); - } -}; -REGISTER_CALCULATOR(EvenIntFilterCalculator); - -// A Calculator that passes packets through or not, depending on a second -// input. The first input stream's packets are only propagated if the second -// input stream carries the value true. -class ValveCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->Inputs().Index(1).Set(); - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - if (cc->Inputs().Index(1).Get()) { - cc->GetCounter("PassThrough")->Increment(); - cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - } else { - cc->GetCounter("Block")->Increment(); - // The next timestamp bound is the minimum timestamp that the next packet - // can have, so, if we want to inform the downstream that no packet at - // InputTimestamp() is coming, we need to set it to the next value. - // We could also just call SetOffset(TimestampDiff(0)) in Open, and then - // we would not have to call this manually. - cc->Outputs().Index(0).SetNextTimestampBound( - cc->InputTimestamp().NextAllowedInStream()); - } - return ::mediapipe::OkStatus(); - } -}; -REGISTER_CALCULATOR(ValveCalculator); - -// A Calculator that simply passes its input Packets and header through, -// but shifts the timestamp. -class TimeShiftCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - cc->InputSidePackets().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - // Input: arbitrary Packets. - // Output: copy of the input. - cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); - shift_ = cc->InputSidePackets().Index(0).Get(); - cc->SetOffset(shift_); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - cc->GetCounter("PassThrough")->Increment(); - cc->Outputs().Index(0).AddPacket( - cc->Inputs().Index(0).Value().At(cc->InputTimestamp() + shift_)); - return ::mediapipe::OkStatus(); - } - - private: - TimestampDiff shift_; -}; -REGISTER_CALCULATOR(TimeShiftCalculator); - template class TypedEmptySourceCalculator : public CalculatorBase { public: @@ -1055,74 +965,6 @@ class OneShot20MsCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(OneShot20MsCalculator); -// A source calculator that alternates between outputting an integer (0, 1, 2, -// ..., 100) and setting the next timestamp bound. The timestamps of the output -// packets and next timestamp bounds are 0, 10, 20, 30, ... 
-// -// T=0 Output 0 -// T=10 Set timestamp bound -// T=20 Output 1 -// T=30 Set timestamp bound -// ... -// T=2000 Output 100 -class OutputAndBoundSourceCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) override { - counter_ = 0; - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) override { - Timestamp timestamp(counter_); - if (counter_ % 20 == 0) { - cc->Outputs().Index(0).AddPacket( - MakePacket(counter_ / 20).At(timestamp)); - } else { - cc->Outputs().Index(0).SetNextTimestampBound(timestamp); - } - if (counter_ == 2000) { - return tool::StatusStop(); - } - counter_ += 10; - return ::mediapipe::OkStatus(); - } - - private: - int counter_; -}; -REGISTER_CALCULATOR(OutputAndBoundSourceCalculator); - -// A calculator that outputs an initial packet of value 0 at time 0 in the -// Open() method, and then delays each input packet by 20 time units in the -// Process() method. The input stream and output stream have the integer type. -class Delay20Calculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).Set(); - cc->Outputs().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - cc->SetOffset(TimestampDiff(20)); - cc->Outputs().Index(0).AddPacket(MakePacket(0).At(Timestamp(0))); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - const Packet& packet = cc->Inputs().Index(0).Value(); - Timestamp timestamp = packet.Timestamp() + 20; - cc->Outputs().Index(0).AddPacket(packet.At(timestamp)); - return ::mediapipe::OkStatus(); - } -}; -REGISTER_CALCULATOR(Delay20Calculator); - // A source calculator that outputs a packet containing the return value of // pthread_self() (the pthread id of the current thread). class PthreadSelfSourceCalculator : public CalculatorBase { @@ -1322,116 +1164,6 @@ class OutputSidePacketInProcessCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); -// Takes an input stream packet and counts the number of the packets it -// receives. Outputs the total number of packets as a side packet in Close. -class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->OutputSidePackets().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - ++count_; - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Close(CalculatorContext* cc) final { - cc->OutputSidePackets().Index(0).Set( - MakePacket(count_).At(Timestamp::Unset())); - return ::mediapipe::OkStatus(); - } - - int count_ = 0; -}; -REGISTER_CALCULATOR(CountAndOutputSummarySidePacketInCloseCalculator); - -// Takes an input stream packet and passes it (with timestamp intact) as an -// output side packet. This triggers an error in the graph. 
-class OutputSidePacketWithTimestampCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).SetAny(); - cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - cc->OutputSidePackets().Index(0).Set(cc->Inputs().Index(0).Value()); - return ::mediapipe::OkStatus(); - } -}; -REGISTER_CALCULATOR(OutputSidePacketWithTimestampCalculator); - -// Generates an output side packet containing the integer 1. -class IntegerOutputSidePacketCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->OutputSidePackets().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - cc->OutputSidePackets().Index(0).Set(MakePacket(1)); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - LOG(FATAL) << "Not reached."; - return ::mediapipe::OkStatus(); - } -}; -REGISTER_CALCULATOR(IntegerOutputSidePacketCalculator); - -// Generates an output side packet containing the sum of the two integer input -// side packets. -class SidePacketAdderCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Index(0).Set(); - cc->InputSidePackets().Index(1).Set(); - cc->OutputSidePackets().Index(0).Set(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - cc->OutputSidePackets().Index(0).Set( - MakePacket(cc->InputSidePackets().Index(1).Get() + - cc->InputSidePackets().Index(0).Get())); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - LOG(FATAL) << "Not reached."; - return ::mediapipe::OkStatus(); - } -}; -REGISTER_CALCULATOR(SidePacketAdderCalculator); - -// Produces an output packet with the PostStream timestamp containing the -// input side packet. -class SidePacketToStreamPacketCalculator : public CalculatorBase { - public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Index(0).SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Open(CalculatorContext* cc) final { - cc->Outputs().Index(0).AddPacket( - cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); - cc->Outputs().Index(0).Close(); - return ::mediapipe::OkStatus(); - } - - ::mediapipe::Status Process(CalculatorContext* cc) final { - return ::mediapipe::tool::StatusStop(); - } -}; -REGISTER_CALCULATOR(SidePacketToStreamPacketCalculator); - // A calculator checks if either of two input streams contains a packet and // sends the packet to the single output stream with the same timestamp. class SimpleMuxCalculator : public CalculatorBase { @@ -2411,205 +2143,6 @@ TEST(CalculatorGraph, InputPacketLifetime) { MP_EXPECT_OK(graph.WaitUntilDone()); } -// Test that SetNextTimestampBound propagates. 
-TEST(CalculatorGraph, SetNextTimestampBoundPropagation) { - CalculatorGraph graph; - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: 'in' - input_stream: 'gate' - node { - calculator: 'ValveCalculator' - input_stream: 'in' - input_stream: 'gate' - output_stream: 'gated' - } - node { - calculator: 'PassThroughCalculator' - input_stream: 'gated' - output_stream: 'passed' - } - node { - calculator: 'TimeShiftCalculator' - input_stream: 'passed' - output_stream: 'shifted' - input_side_packet: 'shift' - } - node { - calculator: 'MergeCalculator' - input_stream: 'in' - input_stream: 'shifted' - output_stream: 'merged' - } - node { - name: 'merged_output' - calculator: 'PassThroughCalculator' - input_stream: 'merged' - output_stream: 'out' - } - )"); - - Timestamp timestamp = Timestamp(0); - auto send_inputs = [&graph, ×tamp](int input, bool pass) { - ++timestamp; - MP_EXPECT_OK(graph.AddPacketToInputStream( - "in", MakePacket(input).At(timestamp))); - MP_EXPECT_OK(graph.AddPacketToInputStream( - "gate", MakePacket(pass).At(timestamp))); - }; - - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.StartRun({{"shift", MakePacket(0)}})); - - auto pass_counter = - graph.GetCounterFactory()->GetCounter("ValveCalculator-PassThrough"); - auto block_counter = - graph.GetCounterFactory()->GetCounter("ValveCalculator-Block"); - auto merged_counter = - graph.GetCounterFactory()->GetCounter("merged_output-PassThrough"); - - send_inputs(1, true); - send_inputs(2, true); - send_inputs(3, false); - send_inputs(4, false); - MP_ASSERT_OK(graph.WaitUntilIdle()); - - // Verify that MergeCalculator was able to run even when the gated branch - // was blocked. - EXPECT_EQ(2, pass_counter->Get()); - EXPECT_EQ(2, block_counter->Get()); - EXPECT_EQ(4, merged_counter->Get()); - - send_inputs(5, true); - send_inputs(6, false); - MP_ASSERT_OK(graph.WaitUntilIdle()); - - EXPECT_EQ(3, pass_counter->Get()); - EXPECT_EQ(3, block_counter->Get()); - EXPECT_EQ(6, merged_counter->Get()); - - MP_ASSERT_OK(graph.CloseAllInputStreams()); - MP_ASSERT_OK(graph.WaitUntilDone()); - - // Now test with time shift - MP_ASSERT_OK(graph.StartRun({{"shift", MakePacket(-1)}})); - - send_inputs(7, true); - MP_ASSERT_OK(graph.WaitUntilIdle()); - - // The merger should have run only once now, at timestamp 6, with inputs - // . If we do not respect the offset and unblock the merger for - // timestamp 7 too, then it will have run twice, with 6: and - // 7: <7, null>. - EXPECT_EQ(4, pass_counter->Get()); - EXPECT_EQ(3, block_counter->Get()); - EXPECT_EQ(7, merged_counter->Get()); - - MP_ASSERT_OK(graph.CloseAllInputStreams()); - MP_ASSERT_OK(graph.WaitUntilDone()); - - EXPECT_EQ(4, pass_counter->Get()); - EXPECT_EQ(3, block_counter->Get()); - EXPECT_EQ(8, merged_counter->Get()); -} - -// Both input streams of the calculator node have the same next timestamp -// bound. One input stream has a packet at that timestamp. The other input -// stream is empty. We should not run the Process() method of the node in this -// case. 
-TEST(CalculatorGraph, NotAllInputPacketsAtNextTimestampBoundAvailable) { - // - // in0_unfiltered in1_to_be_filtered - // | | - // | V - // | +-----------------------+ - // | |EvenIntFilterCalculator| - // | +-----------------------+ - // | | - // \ / - // \ / in1_filtered - // \ / - // | | - // V V - // +------------------+ - // |IntAdderCalculator| - // +------------------+ - // | - // V - // out - // - CalculatorGraph graph; - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: 'in0_unfiltered' - input_stream: 'in1_to_be_filtered' - node { - calculator: 'EvenIntFilterCalculator' - input_stream: 'in1_to_be_filtered' - output_stream: 'in1_filtered' - } - node { - calculator: 'IntAdderCalculator' - input_stream: 'in0_unfiltered' - input_stream: 'in1_filtered' - output_stream: 'out' - } - )"); - std::vector packet_dump; - tool::AddVectorSink("out", &config, &packet_dump); - - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.StartRun({})); - - Timestamp timestamp = Timestamp(0); - - // We send an integer with timestamp 1 to the in0_unfiltered input stream of - // the IntAdderCalculator. We then send an even integer with timestamp 1 to - // the EvenIntFilterCalculator. This packet will go through and - // the IntAdderCalculator will run. The next timestamp bounds of both the - // input streams of the IntAdderCalculator will become 2. - - ++timestamp; // Timestamp 1. - MP_EXPECT_OK(graph.AddPacketToInputStream("in0_unfiltered", - MakePacket(1).At(timestamp))); - MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered", - MakePacket(2).At(timestamp))); - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, packet_dump.size()); - EXPECT_EQ(3, packet_dump[0].Get()); - - // We send an odd integer with timestamp 2 to the EvenIntFilterCalculator. - // This packet will be filtered out and the next timestamp bound of the - // in1_filtered input stream of the IntAdderCalculator will become 3. - - ++timestamp; // Timestamp 2. - MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered", - MakePacket(3).At(timestamp))); - - // We send an integer with timestamp 3 to the in0_unfiltered input stream of - // the IntAdderCalculator. MediaPipe should propagate the next timestamp bound - // across the IntAdderCalculator but should not run its Process() method. - - ++timestamp; // Timestamp 3. - MP_EXPECT_OK(graph.AddPacketToInputStream("in0_unfiltered", - MakePacket(3).At(timestamp))); - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, packet_dump.size()); - - // We send an even integer with timestamp 3 to the IntAdderCalculator. This - // packet will go through and the IntAdderCalculator will run. - - MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered", - MakePacket(4).At(timestamp))); - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(2, packet_dump.size()); - EXPECT_EQ(7, packet_dump[1].Get()); - - MP_ASSERT_OK(graph.CloseAllInputStreams()); - MP_ASSERT_OK(graph.WaitUntilDone()); - EXPECT_EQ(2, packet_dump.size()); -} - // Demonstrate an if-then-else graph. TEST(CalculatorGraph, IfThenElse) { // This graph has an if-then-else structure. 
The left branch, selected by the @@ -3616,134 +3149,6 @@ class PassThroughGenerator : public PacketGenerator { } }; REGISTER_PACKET_GENERATOR(PassThroughGenerator); - -TEST(CalculatorGraph, SharePacketGeneratorGraph) { - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - node { - calculator: 'CountingSourceCalculator' - output_stream: 'count1' - input_side_packet: 'MAX_COUNT:max_count1' - } - node { - calculator: 'CountingSourceCalculator' - output_stream: 'count2' - input_side_packet: 'MAX_COUNT:max_count2' - } - node { - calculator: 'CountingSourceCalculator' - output_stream: 'count3' - input_side_packet: 'MAX_COUNT:max_count3' - } - node { - calculator: 'CountingSourceCalculator' - output_stream: 'count4' - input_side_packet: 'MAX_COUNT:max_count4' - } - node { - calculator: 'PassThroughCalculator' - input_side_packet: 'MAX_COUNT:max_count5' - output_side_packet: 'MAX_COUNT:max_count6' - } - node { - calculator: 'CountingSourceCalculator' - output_stream: 'count5' - input_side_packet: 'MAX_COUNT:max_count6' - } - packet_generator { - packet_generator: 'PassThroughGenerator' - input_side_packet: 'max_count1' - output_side_packet: 'max_count2' - } - packet_generator { - packet_generator: 'PassThroughGenerator' - input_side_packet: 'max_count4' - output_side_packet: 'max_count5' - } - )"); - - // At this point config is a standard config which specifies both - // calculators and packet_factories/packet_genators. The following - // code is an example of reusing side packets across a number of - // CalculatorGraphs. It is particularly informative to note how each - // side packet is created. - // - // max_count1 is set for all graphs by a PacketFactory in the config. - // The side packet is created by generator_graph.InitializeGraph(). - // - // max_count2 is set for all graphs by a PacketGenerator in the config. - // The side packet is created by generator_graph.InitializeGraph() - // because max_count1 is available at that time. - // - // max_count3 is set for all graphs by directly being specified as an - // argument to generator_graph.InitializeGraph(). - // - // max_count4 is set per graph because it is directly specified as an - // argument to generator_graph.ProcessGraph(). - // - // max_count5 is set per graph by a PacketGenerator which is run when - // generator_graph.ProcessGraph() is run (because max_count4 isn't - // available until then). - - // Before anything else, split the graph config into two parts, one - // with the PacketFactory and PacketGenerator config and the other - // with the Calculator config. - CalculatorGraphConfig calculator_config = config; - calculator_config.clear_packet_factory(); - calculator_config.clear_packet_generator(); - CalculatorGraphConfig generator_config = config; - generator_config.clear_node(); - - // Next, create a ValidatedGraphConfig for both configs. - ValidatedGraphConfig validated_calculator_config; - MP_ASSERT_OK(validated_calculator_config.Initialize(calculator_config)); - ValidatedGraphConfig validated_generator_config; - MP_ASSERT_OK(validated_generator_config.Initialize(generator_config)); - - // Create a PacketGeneratorGraph. Side packets max_count1, max_count2, - // and max_count3 are created upon initialization. - // Note that validated_generator_config must outlive generator_graph. 
- PacketGeneratorGraph generator_graph; - MP_ASSERT_OK( - generator_graph.Initialize(&validated_generator_config, nullptr, - {{"max_count1", MakePacket(10)}, - {"max_count3", MakePacket(20)}})); - ASSERT_THAT(generator_graph.BasePackets(), - testing::ElementsAre(testing::Key("max_count1"), - testing::Key("max_count2"), - testing::Key("max_count3"))); - - // Create a bunch of graphs. - std::vector> graphs; - for (int i = 0; i < 100; ++i) { - graphs.emplace_back(absl::make_unique()); - // Do not pass extra side packets here. - // Note that validated_calculator_config must outlive the graph. - MP_ASSERT_OK(graphs.back()->Initialize(calculator_config, {})); - } - // Run a bunch of graphs, reusing side packets max_count1, max_count2, - // and max_count3. The side packet max_count4 is added per run, - // and triggers the execution of a packet generator which generates - // max_count5. - for (int i = 0; i < 100; ++i) { - std::map all_side_packets; - // Creates max_count4 and max_count5. - MP_ASSERT_OK(generator_graph.RunGraphSetup( - {{"max_count4", MakePacket(30 + i)}}, &all_side_packets)); - ASSERT_THAT(all_side_packets, - testing::ElementsAre( - testing::Key("max_count1"), testing::Key("max_count2"), - testing::Key("max_count3"), testing::Key("max_count4"), - testing::Key("max_count5"))); - // Pass all the side packets prepared by generator_graph here. - MP_ASSERT_OK(graphs[i]->Run(all_side_packets)); - // TODO Verify the actual output. - } - - // Destroy all the graphs. - graphs.clear(); -} - TEST(CalculatorGraph, RecoverAfterRunError) { PacketGeneratorGraph generator_graph; CalculatorGraphConfig config = @@ -4371,47 +3776,6 @@ TEST(CalculatorGraph, RecoverAfterPreviousFailInOpen) { } } -TEST(CalculatorGraph, PropagateBoundLoop) { - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - node { - calculator: 'OutputAndBoundSourceCalculator' - output_stream: 'integers' - } - node { - calculator: 'IntAdderCalculator' - input_stream: 'integers' - input_stream: 'old_sum' - input_stream_info: { - tag_index: ':1' # 'old_sum' - back_edge: true - } - output_stream: 'sum' - input_stream_handler { - input_stream_handler: 'EarlyCloseInputStreamHandler' - } - } - node { - calculator: 'Delay20Calculator' - input_stream: 'sum' - output_stream: 'old_sum' - } - )"); - std::vector packet_dump; - tool::AddVectorSink("sum", &config, &packet_dump); - - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.Run()); - ASSERT_EQ(101, packet_dump.size()); - int sum = 0; - for (int i = 0; i < 101; ++i) { - sum += i; - EXPECT_EQ(sum, packet_dump[i].Get()); - EXPECT_EQ(Timestamp(i * 20), packet_dump[i].Timestamp()); - } -} - TEST(CalculatorGraph, ReuseValidatedGraphConfig) { CalculatorGraphConfig config = ::mediapipe::ParseTextProtoOrDie(R"( @@ -4979,176 +4343,6 @@ TEST(CalculatorGraph, CheckInputTimestamp2) { MP_ASSERT_OK(graph.Run()); } -TEST(CalculatorGraph, CheckBatchProcessingBoundPropagation) { - // The timestamp bound sent by OutputAndBoundSourceCalculator shouldn't be - // directly propagated to the output stream when PassThroughCalculator has - // anything in its default calculator context for batch processing. Otherwise, - // the sink calculator's input stream should report packet timestamp - // mismatches. 
- CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - node { - calculator: 'OutputAndBoundSourceCalculator' - output_stream: 'integers' - } - node { - calculator: 'PassThroughCalculator' - input_stream: 'integers' - output_stream: 'output' - input_stream_handler { - input_stream_handler: "DefaultInputStreamHandler" - options: { - [mediapipe.DefaultInputStreamHandlerOptions.ext]: { - batch_size: 10 - } - } - } - } - node { calculator: 'IntSinkCalculator' input_stream: 'output' } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.Run()); -} - -TEST(CalculatorGraph, OutputSidePacketInProcess) { - const int64 offset = 100; - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: "offset" - node { - calculator: "OutputSidePacketInProcessCalculator" - input_stream: "offset" - output_side_packet: "offset" - } - node { - calculator: "SidePacketToStreamPacketCalculator" - output_stream: "output" - input_side_packet: "offset" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - std::vector output_packets; - MP_ASSERT_OK(graph.ObserveOutputStream( - "output", [&output_packets](const Packet& packet) { - output_packets.push_back(packet); - return ::mediapipe::OkStatus(); - })); - - // Run the graph twice. - for (int run = 0; run < 2; ++run) { - output_packets.clear(); - MP_ASSERT_OK(graph.StartRun({})); - MP_ASSERT_OK(graph.AddPacketToInputStream( - "offset", MakePacket(offset).At(Timestamp(0)))); - MP_ASSERT_OK(graph.CloseInputStream("offset")); - MP_ASSERT_OK(graph.WaitUntilDone()); - ASSERT_EQ(1, output_packets.size()); - EXPECT_EQ(offset, output_packets[0].Get().Value()); - } -} - -TEST(CalculatorGraph, OutputSidePacketAlreadySet) { - const int64 offset = 100; - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: "offset" - node { - calculator: "OutputSidePacketInProcessCalculator" - input_stream: "offset" - output_side_packet: "offset" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.StartRun({})); - // Send two input packets to cause OutputSidePacketInProcessCalculator to - // set the output side packet twice. - MP_ASSERT_OK(graph.AddPacketToInputStream( - "offset", MakePacket(offset).At(Timestamp(0)))); - MP_ASSERT_OK(graph.AddPacketToInputStream( - "offset", MakePacket(offset).At(Timestamp(1)))); - MP_ASSERT_OK(graph.CloseInputStream("offset")); - - ::mediapipe::Status status = graph.WaitUntilDone(); - EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kAlreadyExists); - EXPECT_THAT(status.message(), testing::HasSubstr("was already set.")); -} - -TEST(CalculatorGraph, OutputSidePacketWithTimestamp) { - const int64 offset = 100; - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: "offset" - node { - calculator: "OutputSidePacketWithTimestampCalculator" - input_stream: "offset" - output_side_packet: "offset" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - MP_ASSERT_OK(graph.StartRun({})); - // The OutputSidePacketWithTimestampCalculator neglects to clear the - // timestamp in the input packet when it copies the input packet to the - // output side packet. The timestamp value should appear in the error - // message. 
- MP_ASSERT_OK(graph.AddPacketToInputStream( - "offset", MakePacket(offset).At(Timestamp(237)))); - MP_ASSERT_OK(graph.CloseInputStream("offset")); - ::mediapipe::Status status = graph.WaitUntilDone(); - EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument); - EXPECT_THAT(status.message(), testing::HasSubstr("has a timestamp 237.")); -} - -TEST(CalculatorGraph, OutputSidePacketConsumedBySourceNode) { - const int max_count = 10; - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: "max_count" - node { - calculator: "OutputSidePacketInProcessCalculator" - input_stream: "max_count" - output_side_packet: "max_count" - } - node { - calculator: "CountingSourceCalculator" - output_stream: "count" - input_side_packet: "MAX_COUNT:max_count" - } - node { - calculator: "PassThroughCalculator" - input_stream: "count" - output_stream: "output" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - std::vector output_packets; - MP_ASSERT_OK(graph.ObserveOutputStream( - "output", [&output_packets](const Packet& packet) { - output_packets.push_back(packet); - return ::mediapipe::OkStatus(); - })); - MP_ASSERT_OK(graph.StartRun({})); - // Wait until the graph is idle so that - // Scheduler::TryToScheduleNextSourceLayer() gets called. - // Scheduler::TryToScheduleNextSourceLayer() should not activate source - // nodes that haven't been opened. We can't call graph.WaitUntilIdle() - // because the graph has a source node. - absl::SleepFor(absl::Milliseconds(10)); - MP_ASSERT_OK(graph.AddPacketToInputStream( - "max_count", MakePacket(max_count).At(Timestamp(0)))); - MP_ASSERT_OK(graph.CloseInputStream("max_count")); - MP_ASSERT_OK(graph.WaitUntilDone()); - ASSERT_EQ(max_count, output_packets.size()); - for (int i = 0; i < output_packets.size(); ++i) { - EXPECT_EQ(i, output_packets[i].Get()); - EXPECT_EQ(Timestamp(i), output_packets[i].Timestamp()); - } -} - TEST(CalculatorGraph, GraphInputStreamWithTag) { CalculatorGraphConfig config = ::mediapipe::ParseTextProtoOrDie(R"( @@ -5201,245 +4395,6 @@ class FirstPacketFilterCalculator : public CalculatorBase { bool seen_first_packet_ = false; }; REGISTER_CALCULATOR(FirstPacketFilterCalculator); - -TEST(CalculatorGraph, SourceLayerInversion) { - // There are three CountingSourceCalculators, indexed 0, 1, and 2. Each of - // them outputs 10 packets. - // - // CountingSourceCalculator 0 should output 0, 1, 2, 3, ..., 9. - // CountingSourceCalculator 1 should output 100, 101, 102, 103, ..., 109. - // CountingSourceCalculator 2 should output 0, 100, 200, 300, ..., 900. - // However, there is a source layer inversion. - // CountingSourceCalculator 0 is in source layer 0. - // CountingSourceCalculator 1 is in source layer 1. - // CountingSourceCalculator 2 is in source layer 0, but consumes an output - // side packet generated by a downstream calculator of - // CountingSourceCalculator 1. - // - // This graph will deadlock when CountingSourceCalculator 0 runs to - // completion and CountingSourceCalculator 1 cannot be activated because - // CountingSourceCalculator 2 cannot be opened. - - const int max_count = 10; - const int initial_value1 = 100; - // Set num_threads to 1 to force sequential execution for deterministic - // outputs. 
- CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - num_threads: 1 - node { - calculator: "CountingSourceCalculator" - output_stream: "count0" - input_side_packet: "MAX_COUNT:max_count" - source_layer: 0 - } - - node { - calculator: "CountingSourceCalculator" - output_stream: "count1" - input_side_packet: "MAX_COUNT:max_count" - input_side_packet: "INITIAL_VALUE:initial_value1" - source_layer: 1 - } - node { - calculator: "FirstPacketFilterCalculator" - input_stream: "count1" - output_stream: "first_count1" - } - node { - calculator: "OutputSidePacketInProcessCalculator" - input_stream: "first_count1" - output_side_packet: "increment2" - } - - node { - calculator: "CountingSourceCalculator" - output_stream: "count2" - input_side_packet: "MAX_COUNT:max_count" - input_side_packet: "INCREMENT:increment2" - source_layer: 0 - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize( - config, {{"max_count", MakePacket(max_count)}, - {"initial_value1", MakePacket(initial_value1)}})); - ::mediapipe::Status status = graph.Run(); - EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnknown); - EXPECT_THAT(status.message(), testing::HasSubstr("deadlock")); -} - -// Tests a graph of packet-generator-like calculators, which have no input -// streams and no output streams. -TEST(CalculatorGraph, PacketGeneratorLikeCalculators) { - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - node { - calculator: "IntegerOutputSidePacketCalculator" - output_side_packet: "one" - } - node { - calculator: "IntegerOutputSidePacketCalculator" - output_side_packet: "another_one" - } - node { - calculator: "SidePacketAdderCalculator" - input_side_packet: "one" - input_side_packet: "another_one" - output_side_packet: "two" - } - node { - calculator: "IntegerOutputSidePacketCalculator" - output_side_packet: "yet_another_one" - } - node { - calculator: "SidePacketAdderCalculator" - input_side_packet: "two" - input_side_packet: "yet_another_one" - output_side_packet: "three" - } - node { - calculator: "SidePacketToStreamPacketCalculator" - input_side_packet: "three" - output_stream: "output" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - std::vector output_packets; - MP_ASSERT_OK(graph.ObserveOutputStream( - "output", [&output_packets](const Packet& packet) { - output_packets.push_back(packet); - return ::mediapipe::OkStatus(); - })); - MP_ASSERT_OK(graph.Run()); - ASSERT_EQ(1, output_packets.size()); - EXPECT_EQ(3, output_packets[0].Get()); - EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp()); -} - -TEST(CalculatorGraph, OutputSummarySidePacketInClose) { - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: "input_packets" - node { - calculator: "CountAndOutputSummarySidePacketInCloseCalculator" - input_stream: "input_packets" - output_side_packet: "num_of_packets" - } - node { - calculator: "SidePacketToStreamPacketCalculator" - input_side_packet: "num_of_packets" - output_stream: "output" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - std::vector output_packets; - MP_ASSERT_OK(graph.ObserveOutputStream( - "output", [&output_packets](const Packet& packet) { - output_packets.push_back(packet); - return ::mediapipe::OkStatus(); - })); - - // Run the graph twice. 
- int max_count = 100; - for (int run = 0; run < 1; ++run) { - output_packets.clear(); - MP_ASSERT_OK(graph.StartRun({})); - for (int i = 0; i < max_count; ++i) { - MP_ASSERT_OK(graph.AddPacketToInputStream( - "input_packets", MakePacket(i).At(Timestamp(i)))); - } - MP_ASSERT_OK(graph.CloseInputStream("input_packets")); - MP_ASSERT_OK(graph.WaitUntilDone()); - ASSERT_EQ(1, output_packets.size()); - EXPECT_EQ(max_count, output_packets[0].Get()); - EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp()); - } -} - -TEST(CalculatorGraph, GetOutputSidePacket) { - CalculatorGraphConfig config = - ::mediapipe::ParseTextProtoOrDie(R"( - input_stream: "input_packets" - node { - calculator: "CountAndOutputSummarySidePacketInCloseCalculator" - input_stream: "input_packets" - output_side_packet: "num_of_packets" - } - packet_generator { - packet_generator: "Uint64PacketGenerator" - output_side_packet: "output_uint64" - } - packet_generator { - packet_generator: "IntSplitterPacketGenerator" - input_side_packet: "input_uint64" - output_side_packet: "output_uint32_pair" - } - )"); - CalculatorGraph graph; - MP_ASSERT_OK(graph.Initialize(config)); - // Check a packet generated by the PacketGenerator, which is available after - // graph initialization, can be fetched before graph starts. - ::mediapipe::StatusOr status_or_packet = - graph.GetOutputSidePacket("output_uint64"); - MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); - // IntSplitterPacketGenerator is missing its input side packet and we - // won't be able to get its output side packet now. - status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); - EXPECT_EQ(::mediapipe::StatusCode::kUnavailable, - status_or_packet.status().code()); - // Run the graph twice. - int max_count = 100; - std::map extra_side_packets; - extra_side_packets.insert({"input_uint64", MakePacket(1123)}); - for (int run = 0; run < 1; ++run) { - MP_ASSERT_OK(graph.StartRun(extra_side_packets)); - status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); - MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); - for (int i = 0; i < max_count; ++i) { - MP_ASSERT_OK(graph.AddPacketToInputStream( - "input_packets", MakePacket(i).At(Timestamp(i)))); - } - MP_ASSERT_OK(graph.CloseInputStream("input_packets")); - - // Should return NOT_FOUND for invalid side packets. - status_or_packet = graph.GetOutputSidePacket("unknown"); - EXPECT_FALSE(status_or_packet.ok()); - EXPECT_EQ(::mediapipe::StatusCode::kNotFound, - status_or_packet.status().code()); - // Should return UNAVAILABLE before graph is done for valid non-base - // packets. - status_or_packet = graph.GetOutputSidePacket("num_of_packets"); - EXPECT_FALSE(status_or_packet.ok()); - EXPECT_EQ(::mediapipe::StatusCode::kUnavailable, - status_or_packet.status().code()); - // Should stil return a base even before graph is done. - status_or_packet = graph.GetOutputSidePacket("output_uint64"); - MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); - - MP_ASSERT_OK(graph.WaitUntilDone()); - - // Check packets are available after graph is done. - status_or_packet = graph.GetOutputSidePacket("num_of_packets"); - MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(max_count, status_or_packet.ValueOrDie().Get()); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); - // Should still return a base packet after graph is done. 
- status_or_packet = graph.GetOutputSidePacket("output_uint64"); - MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); - // Should still return a non-base packet after graph is done. - status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); - MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); - } -} - constexpr int kDefaultMaxCount = 1000; TEST(CalculatorGraph, TestPollPacket) { diff --git a/mediapipe/framework/formats/matrix_data.proto b/mediapipe/framework/formats/matrix_data.proto index deec5c8df..216d01288 100644 --- a/mediapipe/framework/formats/matrix_data.proto +++ b/mediapipe/framework/formats/matrix_data.proto @@ -34,7 +34,7 @@ message MatrixData { ROW_MAJOR = 1; } - // Order in which the data are stored. Implicitly defaults to COLUMN_MAJOR, - // which matches the default for mediapipe::Matrix and Eigen::Matrix*. - optional Layout layout = 4; + // Order in which the data are stored. Defaults to COLUMN_MAJOR, which matches + // the default for mediapipe::Matrix and Eigen::Matrix*. + optional Layout layout = 4 [default = COLUMN_MAJOR]; } diff --git a/mediapipe/framework/graph_validation_test.cc b/mediapipe/framework/graph_validation_test.cc index c67bf57df..98492b8d0 100644 --- a/mediapipe/framework/graph_validation_test.cc +++ b/mediapipe/framework/graph_validation_test.cc @@ -154,11 +154,13 @@ TEST(ValidatedGraphConfigTest, InitializeTemplateFromProtos) { } )"); auto options = ParseTextProtoOrDie(R"( - [mediapipe.TemplateSubgraphOptions.ext]: { - dict: { - arg: { - key: "in_name" - value: { str: "stream_9" } + options: { + [mediapipe.TemplateSubgraphOptions.ext]: { + dict: { + arg: { + key: "in_name" + value: { str: "stream_9" } + } } } })"); diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc index a1479c5bf..8d121e01e 100644 --- a/mediapipe/framework/subgraph.cc +++ b/mediapipe/framework/subgraph.cc @@ -44,8 +44,8 @@ TemplateSubgraph::~TemplateSubgraph() {} ::mediapipe::StatusOr TemplateSubgraph::GetConfig( const Subgraph::SubgraphOptions& options) { - const TemplateDict& arguments = - options.GetExtension(TemplateSubgraphOptions::ext).dict(); + TemplateDict arguments = + Subgraph::GetOptions(options).dict(); tool::TemplateExpander expander; CalculatorGraphConfig config; MP_RETURN_IF_ERROR(expander.ExpandTemplates(arguments, templ_, &config)); diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index abfecbc36..3febde7e9 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -24,6 +24,7 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/calculator_graph_template.pb.h" +#include "mediapipe/framework/tool/options_util.h" namespace mediapipe { @@ -32,7 +33,7 @@ namespace mediapipe { // the graph is running. class Subgraph { public: - using SubgraphOptions = CalculatorOptions; + using SubgraphOptions = CalculatorGraphConfig::Node; Subgraph(); virtual ~Subgraph(); // Returns the config to use for one instantiation of the subgraph. The @@ -42,6 +43,12 @@ class Subgraph { // TODO: make this static? virtual ::mediapipe::StatusOr GetConfig( const SubgraphOptions& options) = 0; + + // Returns options of a specific type. 
+ template + static T GetOptions(Subgraph::SubgraphOptions supgraph_options) { + return tool::OptionsMap().Initialize(supgraph_options).Get(); + } }; using SubgraphRegistry = GlobalFactoryRegistry>; diff --git a/mediapipe/framework/test_calculators.cc b/mediapipe/framework/test_calculators.cc index 21d46c8b8..f3d1f0c79 100644 --- a/mediapipe/framework/test_calculators.cc +++ b/mediapipe/framework/test_calculators.cc @@ -548,8 +548,12 @@ typedef std::function<::mediapipe::Status(const InputStreamShardSet&, OutputStreamShardSet*)> ProcessFunction; +// A callback function for Calculator::Open, Process, or Close. +typedef std::function<::mediapipe::Status(CalculatorContext* cc)> + CalculatorContextFunction; + // A Calculator that runs a testing callback function in Process, -// which is specified as an input side packet. +// Open, or Close, which is specified as an input side packet. class LambdaCalculator : public CalculatorBase { public: static ::mediapipe::Status GetContract(CalculatorContract* cc) { @@ -561,21 +565,49 @@ class LambdaCalculator : public CalculatorBase { id < cc->Outputs().EndId(); ++id) { cc->Outputs().Get(id).SetAny(); } - cc->InputSidePackets().Index(0).Set(); + if (cc->InputSidePackets().HasTag("") > 0) { + cc->InputSidePackets().Tag("").Set(); + } + for (std::string tag : {"OPEN", "PROCESS", "CLOSE"}) { + if (cc->InputSidePackets().HasTag(tag)) { + cc->InputSidePackets().Tag(tag).Set(); + } + } return ::mediapipe::OkStatus(); } ::mediapipe::Status Open(CalculatorContext* cc) final { - callback_ = cc->InputSidePackets().Index(0).Get(); + if (cc->InputSidePackets().HasTag("OPEN")) { + return GetContextFn(cc, "OPEN")(cc); + } return ::mediapipe::OkStatus(); } ::mediapipe::Status Process(CalculatorContext* cc) final { - return callback_(cc->Inputs(), &(cc->Outputs())); + if (cc->InputSidePackets().HasTag("PROCESS")) { + return GetContextFn(cc, "PROCESS")(cc); + } + if (cc->InputSidePackets().HasTag("") > 0) { + return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs()); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Close(CalculatorContext* cc) final { + if (cc->InputSidePackets().HasTag("CLOSE")) { + return GetContextFn(cc, "CLOSE")(cc); + } + return ::mediapipe::OkStatus(); } private: - ProcessFunction callback_; + ProcessFunction GetProcessFn(CalculatorContext* cc, std::string tag) { + return cc->InputSidePackets().Tag(tag).Get(); + } + CalculatorContextFunction GetContextFn(CalculatorContext* cc, + std::string tag) { + return cc->InputSidePackets().Tag(tag).Get(); + } }; REGISTER_CALCULATOR(LambdaCalculator); diff --git a/mediapipe/framework/testdata/BUILD b/mediapipe/framework/testdata/BUILD index 75ee0802d..599576899 100644 --- a/mediapipe/framework/testdata/BUILD +++ b/mediapipe/framework/testdata/BUILD @@ -55,6 +55,14 @@ proto_library( deps = ["@com_google_protobuf//:any_proto"], ) +mediapipe_cc_proto_library( + name = "zoo_mutator_cc_proto", + srcs = ["zoo_mutator.proto"], + cc_deps = ["@com_google_protobuf//:cc_wkt_protos"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":zoo_mutator_proto"], +) + proto_library( name = "zoo_mutation_calculator_proto", srcs = ["zoo_mutation_calculator.proto"], diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index a91b20d92..665fd4cec 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -237,9 +237,9 @@ static ::mediapipe::Status PrefixNames(int subgraph_index, for (auto it = 
subgraph_nodes_start; it != nodes->end(); ++it) { const auto& node = *it; MP_RETURN_IF_ERROR(ValidateSubgraphFields(node)); - ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName( - config->package(), node.calculator(), - &node.options())); + ASSIGN_OR_RETURN(auto subgraph, + graph_registry->CreateByName(config->package(), + node.calculator(), &node)); MP_RETURN_IF_ERROR(PrefixNames(subgraph_counter++, &subgraph)); MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph)); subgraphs.push_back(subgraph); diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc index 8502d7461..8df5eb3c7 100644 --- a/mediapipe/framework/tool/subgraph_expansion_test.cc +++ b/mediapipe/framework/tool/subgraph_expansion_test.cc @@ -128,8 +128,8 @@ class NodeChainSubgraph : public Subgraph { public: ::mediapipe::StatusOr GetConfig( const SubgraphOptions& options) override { - const mediapipe::NodeChainSubgraphOptions& opts = - options.GetExtension(mediapipe::NodeChainSubgraphOptions::ext); + auto opts = + Subgraph::GetOptions(options); const ProtoString& node_type = opts.node_type(); int chain_length = opts.chain_length(); RET_CHECK(!node_type.empty()); diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 46f2384e7..639ab9e24 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -80,8 +80,7 @@ GL_BASE_LINK_OPTS_OSS = GL_BASE_LINK_OPTS + select({ "-lEGL", ], "//mediapipe:android": [], - "//mediapipe:apple": [], - "//mediapipe:macos": [], + "//mediapipe:ios": [], ":disable_gpu": [], }) @@ -289,6 +288,25 @@ objc_library( ], ) +objc_library( + name = "MPPMetalUtil", + srcs = ["MPPMetalUtil.mm"], + hdrs = ["MPPMetalUtil.h"], + copts = [ + "-x objective-c++", + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/objc:mediapipe_framework_ios", + "@google_toolbox_for_mac//:GTM_Defines", + ], +) + proto_library( name = "gl_context_options_proto", srcs = ["gl_context_options.proto"], @@ -499,6 +517,7 @@ cc_library( ":gl_base", ":gl_context", ":gpu_buffer", + ":gpu_buffer_format", ":gpu_buffer_multi_pool", ":gpu_shared_data_internal", ":gpu_service", diff --git a/mediapipe/gpu/MPPMetalUtil.h b/mediapipe/gpu/MPPMetalUtil.h new file mode 100644 index 000000000..328cb99fa --- /dev/null +++ b/mediapipe/gpu/MPPMetalUtil.h @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_MPP_METAL_UTIL_H_ +#define MEDIAPIPE_GPU_MPP_METAL_UTIL_H_ + +#import +#import +#import + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPMetalUtil : NSObject { +} + +/// Copies a Metal Buffer from source to destination. +/// Uses blitCommandEncoder and assumes offset of 0. ++ (void)blitMetalBufferTo:(id)destination + from:(id)source + blocking:(bool)blocking + commandBuffer:(id)commandBuffer; + +/// Copies a Metal Buffer from source to destination. +/// Simple wrapper for blitCommandEncoder. 
+/// Optionally block until operation is completed. ++ (void)blitMetalBufferTo:(id)destination + destinationOffset:(int)destinationOffset + from:(id)source + sourceOffset:(int)sourceOffset + bytes:(size_t)bytes + blocking:(bool)blocking + commandBuffer:(id)commandBuffer; + +@end + +NS_ASSUME_NONNULL_END + +#endif // MEDIAPIPE_GPU_MPP_METAL_UTIL_H_ diff --git a/mediapipe/gpu/MPPMetalUtil.mm b/mediapipe/gpu/MPPMetalUtil.mm new file mode 100644 index 000000000..81cd8a358 --- /dev/null +++ b/mediapipe/gpu/MPPMetalUtil.mm @@ -0,0 +1,51 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/gpu/MPPMetalUtil.h" + +@implementation MPPMetalUtil + ++ (void)blitMetalBufferTo:(id)destination + from:(id)source + blocking:(bool)blocking + commandBuffer:(id)commandBuffer { + size_t bytes = MIN(destination.length, source.length); + [self blitMetalBufferTo:destination + destinationOffset:0 + from:source + sourceOffset:0 + bytes:bytes + blocking:blocking + commandBuffer:commandBuffer]; +} + ++ (void)blitMetalBufferTo:(id)destination + destinationOffset:(int)destinationOffset + from:(id)source + sourceOffset:(int)sourceOffset + bytes:(size_t)bytes + blocking:(bool)blocking + commandBuffer:(id)commandBuffer { + id blit_command = [commandBuffer blitCommandEncoder]; + [blit_command copyFromBuffer:source + sourceOffset:sourceOffset + toBuffer:destination + destinationOffset:destinationOffset + size:bytes]; + [blit_command endEncoding]; + [commandBuffer commit]; + if (blocking) [commandBuffer waitUntilCompleted]; +} + +@end diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h index 1c80917e2..3d92ca671 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ b/mediapipe/gpu/gl_calculator_helper_impl.h @@ -73,7 +73,7 @@ class GlCalculatorHelperImpl { #endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // Sets default texture filtering parameters. - void SetStandardTextureParams(GLenum target); + void SetStandardTextureParams(GLenum target, GLint internal_format); // Create the framebuffer for rendering. void CreateFramebuffer(); diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc index b42618f77..cf2dcf582 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_common.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "mediapipe/gpu/gl_calculator_helper_impl.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" namespace mediapipe { @@ -86,9 +87,21 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) { #endif } -void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target) { - glTexParameteri(target, GL_TEXTURE_MIN_FILTER, GL_LINEAR); - glTexParameteri(target, GL_TEXTURE_MAG_FILTER, GL_LINEAR); +void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target, + GLint internal_format) { + GLint filter; + switch (internal_format) { + case GL_R32F: + case GL_RGBA32F: + // 32F (unlike 16f) textures do not support texture filtering + // (According to OpenGL ES specification [TEXTURE IMAGE SPECIFICATION]) + filter = GL_NEAREST; + break; + default: + filter = GL_LINEAR; + } + glTexParameteri(target, GL_TEXTURE_MIN_FILTER, filter); + glTexParameteri(target, GL_TEXTURE_MAG_FILTER, filter); glTexParameteri(target, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); } @@ -136,7 +149,9 @@ GlTexture GlCalculatorHelperImpl::MapGlTextureBuffer( // TODO: do the params need to be reset here?? glBindTexture(texture.target(), texture.name()); - SetStandardTextureParams(texture.target()); + GlTextureInfo info = + GlTextureInfoForGpuBufferFormat(texture_buffer->format(), texture.plane_); + SetStandardTextureParams(texture.target(), info.gl_internal_format); glBindTexture(texture.target(), 0); return texture; @@ -150,7 +165,9 @@ GlTextureBufferSharedPtr GlCalculatorHelperImpl::MakeGlTextureBuffer( GpuBufferFormatForImageFormat(image_frame.Format()), image_frame.PixelData()); glBindTexture(GL_TEXTURE_2D, buffer->name_); - SetStandardTextureParams(buffer->target_); + GlTextureInfo info = + GlTextureInfoForGpuBufferFormat(buffer->format_, /*plane=*/0); + SetStandardTextureParams(buffer->target_, info.gl_internal_format); glBindTexture(GL_TEXTURE_2D, 0); return buffer; diff --git a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm index d62a1f90d..00b2e643c 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm +++ b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm @@ -54,7 +54,7 @@ GlTexture GlCalculatorHelperImpl::CreateSourceTexture( glTexImage2D(GL_TEXTURE_2D, 0, info.gl_internal_format, texture.width_, texture.height_, 0, info.gl_format, info.gl_type, image_frame.PixelData()); - SetStandardTextureParams(GL_TEXTURE_2D); + SetStandardTextureParams(GL_TEXTURE_2D, info.gl_internal_format); return texture; } @@ -107,7 +107,7 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer( #endif // TARGET_OS_OSX glBindTexture(texture.target(), texture.name()); - SetStandardTextureParams(texture.target()); + SetStandardTextureParams(texture.target(), info.gl_internal_format); return texture; } diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 5a57b32db..79f0c30eb 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -110,7 +110,13 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, eglChooseConfig(display_, config_attr, &config_, 1, &num_configs); if (!success) { return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) - << "eglChooseConfig() returned error " << eglGetError(); + << "eglChooseConfig() returned error " << std::showbase << std::hex + << eglGetError(); + } + if (!num_configs) { + return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + << "eglChooseConfig() returned 
no matching EGL configuration for " + << "RGBA8888 D16 ES" << gl_version << " request. "; } const EGLint context_attr[] = { @@ -125,7 +131,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, int error = eglGetError(); RET_CHECK(context_ != EGL_NO_CONTEXT) << "Could not create GLES " << gl_version << " context; " - << "eglCreateContext() returned error " << error + << "eglCreateContext() returned error " << std::showbase << std::hex + << error << (error == EGL_BAD_CONTEXT ? ": external context uses a different version of OpenGL" : ""); @@ -143,7 +150,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, display_ = eglGetDisplay(EGL_DEFAULT_DISPLAY); RET_CHECK(display_ != EGL_NO_DISPLAY) - << "eglGetDisplay() returned error " << eglGetError(); + << "eglGetDisplay() returned error " << std::showbase << std::hex + << eglGetError(); EGLBoolean success = eglInitialize(display_, &major, &minor); RET_CHECK(success) << "Unable to initialize EGL"; @@ -162,7 +170,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, surface_ = eglCreatePbufferSurface(display_, config_, pbuffer_attr); RET_CHECK(surface_ != EGL_NO_SURFACE) - << "eglCreatePbufferSurface() returned error " << eglGetError(); + << "eglCreatePbufferSurface() returned error " << std::showbase + << std::hex << eglGetError(); return ::mediapipe::OkStatus(); } @@ -186,17 +195,21 @@ void GlContext::DestroyContext() { if (IsCurrent()) { if (!eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT)) { - LOG(ERROR) << "eglMakeCurrent() returned error " << eglGetError(); + LOG(ERROR) << "eglMakeCurrent() returned error " << std::showbase + << std::hex << eglGetError(); } } if (surface_ != EGL_NO_SURFACE) { if (!eglDestroySurface(display_, surface_)) { - LOG(ERROR) << "eglDestroySurface() returned error " << eglGetError(); + LOG(ERROR) << "eglDestroySurface() returned error " << std::showbase + << std::hex << eglGetError(); } + surface_ = EGL_NO_SURFACE; } if (context_ != EGL_NO_CONTEXT) { if (!eglDestroyContext(display_, context_)) { - LOG(ERROR) << "eglDestroyContext() returned error " << eglGetError(); + LOG(ERROR) << "eglDestroyContext() returned error " << std::showbase + << std::hex << eglGetError(); } context_ = EGL_NO_CONTEXT; } @@ -245,7 +258,8 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { EGLBoolean success = eglMakeCurrent(display, new_binding.draw_surface, new_binding.read_surface, new_binding.context); - RET_CHECK(success) << "eglMakeCurrent() returned error " << eglGetError(); + RET_CHECK(success) << "eglMakeCurrent() returned error " << std::showbase + << std::hex << eglGetError(); return ::mediapipe::OkStatus(); } diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 80cfd20d5..e8dbda3e3 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -77,7 +77,8 @@ bool GlTextureBuffer::CreateInternal(const void* data) { // TODO: maybe we do not actually have to wait for the // consumer sync here. Check docs. 
sync_token->WaitOnGpu(); - DCHECK(glIsTexture(name_to_delete)); + DLOG_IF(ERROR, !glIsTexture(name_to_delete)) + << "Deleting invalid texture id: " << name_to_delete; glDeleteTextures(1, &name_to_delete); }); }; diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc new file mode 100644 index 000000000..a4bd93a7a --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/gpu/gpu_buffer.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/gpu/gpu_test_base.h" + +namespace mediapipe { +namespace { + +class GpuBufferTest : public GpuTestBase {}; + +TEST_F(GpuBufferTest, BasicTest) { + RunInGlContext([this] { + GpuBuffer buffer = gpu_shared_.gpu_buffer_pool.GetBuffer(300, 200); + EXPECT_EQ(buffer.width(), 300); + EXPECT_EQ(buffer.height(), 200); + EXPECT_TRUE(buffer); + EXPECT_FALSE(buffer == nullptr); + + GpuBuffer no_buffer; + EXPECT_FALSE(no_buffer); + EXPECT_TRUE(no_buffer == nullptr); + + GpuBuffer buffer2 = buffer; + EXPECT_EQ(buffer, buffer); + EXPECT_EQ(buffer, buffer2); + EXPECT_NE(buffer, no_buffer); + + buffer = nullptr; + EXPECT_TRUE(buffer == nullptr); + EXPECT_TRUE(buffer == no_buffer); + }); +} + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 65f9d8891..b829c4f63 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -123,7 +123,7 @@ struct GpuSharedData { PlatformGlContext external_context) { auto status_or_resources = GpuResources::Create(external_context); MEDIAPIPE_CHECK_OK(status_or_resources.status()) - << "could not create GpuResources"; + << ": could not create GpuResources"; return std::move(status_or_resources).ValueOrDie(); } }; diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h new file mode 100644 index 000000000..e9fd64725 --- /dev/null +++ b/mediapipe/gpu/gpu_test_base.h @@ -0,0 +1,39 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_GPU_GPU_TEST_BASE_H_ +#define MEDIAPIPE_GPU_GPU_TEST_BASE_H_ + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_shared_data_internal.h" + +namespace mediapipe { + +class GpuTestBase : public ::testing::Test { + protected: + GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); } + + void RunInGlContext(std::function gl_func) { + helper_.RunInGlContext(std::move(gl_func)); + } + + GpuSharedData gpu_shared_; + GlCalculatorHelper helper_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GPU_TEST_BASE_H_ diff --git a/mediapipe/graphs/face_detection/BUILD b/mediapipe/graphs/face_detection/BUILD index 0281f3437..ccc9995d6 100644 --- a/mediapipe/graphs/face_detection/BUILD +++ b/mediapipe/graphs/face_detection/BUILD @@ -35,6 +35,23 @@ cc_library( ], ) +cc_library( + name = "desktop_tflite_calculators", + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:detections_to_render_data_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator", + ], +) + load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_binary_graph", ) diff --git a/mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt b/mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt new file mode 100644 index 000000000..95fdb3623 --- /dev/null +++ b/mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt @@ -0,0 +1,184 @@ +# MediaPipe graph that performs face detection with TensorFlow Lite on CPU. +# Used in the examples in +# mediapipe/examples/desktop/face_detection:face_detection_cpu. + +# Images on CPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessary computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs.
+node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on CPU to a 128x128 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:throttled_input_video" + output_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 128 + output_height: 128 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video_cpu" + output_stream: "TENSORS:image_tensor" +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/face_detection_front.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 4 + min_scale: 0.1484375 + max_scale: 0.75 + input_size_height: 128 + input_size_width: 128 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 16 + strides: 16 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 896 + num_coords: 16 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 6 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.75 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text ("Face"). The label +# map is provided in the label_map_path option. 
+node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:output_detections" +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:throttled_input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} + diff --git a/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt b/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt index d52ef6c5c..1b6ecbf47 100644 --- a/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt +++ b/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt @@ -41,7 +41,7 @@ node: { output_stream: "input_video_cpu" } -# Transforms the input image on GPU to a 128x128 image. To scale the input +# Transforms the input image on CPU to a 128x128 image. To scale the input # image, the scale_mode option is set to FIT to preserve the aspect ratio, # resulting in potential letterboxing in the transformed image. 
node: { @@ -75,7 +75,7 @@ node { output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "face_detection_front.tflite" + model_path: "mediapipe/models/face_detection_front.tflite" } } } @@ -156,7 +156,7 @@ node { output_stream: "labeled_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "face_detection_front_labelmap.txt" + label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" } } } @@ -179,7 +179,7 @@ node { output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { - thickness: 10.0 + thickness: 4.0 color { r: 255 g: 0 b: 0 } } } diff --git a/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt b/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt index e12787d5b..2fb85bc00 100644 --- a/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt +++ b/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt @@ -62,10 +62,10 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS:detection_tensors" + output_stream: "TENSORS_GPU:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "face_detection_front.tflite" + model_path: "mediapipe/models/face_detection_front.tflite" } } } @@ -99,7 +99,7 @@ node { # detections. Each detection describes a detected object. node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS:detection_tensors" + input_stream: "TENSORS_GPU:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { @@ -146,7 +146,7 @@ node { output_stream: "labeled_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "face_detection_front_labelmap.txt" + label_map_path: "mediapipe/models/face_detection_front_labelmap.txt" } } } @@ -169,7 +169,7 @@ node { output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { - thickness: 10.0 + thickness: 4.0 color { r: 255 g: 0 b: 0 } } } diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt index ed5d0ada4..c8db44d40 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt @@ -111,7 +111,7 @@ node { input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "hair_segmentation.tflite" + model_path: "mediapipe/models/hair_segmentation.tflite" use_gpu: true } } diff --git a/mediapipe/graphs/hand_tracking/BUILD b/mediapipe/graphs/hand_tracking/BUILD index 73c5a6ce3..09a8e4d0f 100644 --- a/mediapipe/graphs/hand_tracking/BUILD +++ b/mediapipe/graphs/hand_tracking/BUILD @@ -19,73 +19,35 @@ package(default_visibility = ["//visibility:public"]) load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_binary_graph", - "mediapipe_simple_subgraph", ) -mediapipe_simple_subgraph( - name = "hand_detection_gpu", - graph = "hand_detection_gpu.pbtxt", - register_as = "HandDetectionSubgraph", +cc_library( + name = "desktop_tflite_calculators", deps = [ - 
"//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", - "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator", - "//mediapipe/calculators/util:detection_letterbox_removal_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:non_max_suppression_calculator", - "//mediapipe/calculators/util:rect_transformation_calculator", - ], -) - -mediapipe_simple_subgraph( - name = "hand_landmark_gpu", - graph = "hand_landmark_gpu.pbtxt", - register_as = "HandLandmarkSubgraph", - deps = [ - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/calculators/image:image_cropping_calculator", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", - "//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", - "//mediapipe/calculators/util:landmark_projection_calculator", - "//mediapipe/calculators/util:landmarks_to_detection_calculator", - "//mediapipe/calculators/util:rect_transformation_calculator", - "//mediapipe/calculators/util:thresholding_calculator", - ], -) - -mediapipe_simple_subgraph( - name = "renderer_gpu", - graph = "renderer_gpu.pbtxt", - register_as = "RendererSubgraph", - deps = [ - "//mediapipe/calculators/util:annotation_overlay_calculator", - "//mediapipe/calculators/util:detections_to_render_data_calculator", - "//mediapipe/calculators/util:landmarks_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_data_calculator", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:immediate_mux_calculator", + "//mediapipe/calculators/core:merge_calculator", + "//mediapipe/calculators/core:packet_inner_join_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/video:opencv_video_decoder_calculator", + "//mediapipe/calculators/video:opencv_video_encoder_calculator", + "//mediapipe/graphs/hand_tracking/subgraphs:hand_detection_cpu", + "//mediapipe/graphs/hand_tracking/subgraphs:hand_landmark_cpu", + "//mediapipe/graphs/hand_tracking/subgraphs:renderer_cpu", ], ) cc_library( name = "mobile_calculators", deps = [ - ":hand_detection_gpu", - ":hand_landmark_gpu", - ":renderer_gpu", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:merge_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/graphs/hand_tracking/subgraphs:hand_detection_gpu", + "//mediapipe/graphs/hand_tracking/subgraphs:hand_landmark_gpu", + "//mediapipe/graphs/hand_tracking/subgraphs:renderer_gpu", ], ) @@ -99,9 +61,9 @@ mediapipe_binary_graph( cc_library( 
name = "detection_mobile_calculators", deps = [ - ":hand_detection_gpu", - ":renderer_gpu", "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/graphs/hand_tracking/subgraphs:hand_detection_gpu", + "//mediapipe/graphs/hand_tracking/subgraphs:renderer_gpu", ], ) diff --git a/mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt b/mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt new file mode 100644 index 000000000..ac8e7a401 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt @@ -0,0 +1,62 @@ +# MediaPipe graph that performs hand detection on desktop with TensorFlow Lite +# on CPU. +# Used in the example in +# mediapipie/examples/desktop/hand_tracking:hand_detection_tflite. + +# max_queue_size limits the number of packets enqueued on any input stream +# by throttling inputs to the graph. This makes the graph only process one +# frame per time. +max_queue_size: 1 + +# Decodes an input video file into images and a video header. +node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_video_path" + output_stream: "VIDEO:input_video" + output_stream: "VIDEO_PRESTREAM:input_video_header" +} + +# Performs hand detection model on the input frames. See +# hand_detection_cpu.pbtxt for the detail of the sub-graph. +node { + calculator: "HandDetectionSubgraph" + input_stream: "input_video" + output_stream: "DETECTIONS:output_detections" +} + +# Converts the detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the original image coming into +# the graph. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} + +# Encodes the annotated images into a video file, adopting properties specified +# in the input video header, e.g., video framerate. +node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:output_video" + input_stream: "VIDEO_PRESTREAM:input_video_header" + input_side_packet: "OUTPUT_FILE_PATH:output_video_path" + node_options: { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "avc1" + video_format: "mp4" + } + } +} diff --git a/mediapipe/graphs/hand_tracking/hand_detection_desktop_live.pbtxt b/mediapipe/graphs/hand_tracking/hand_detection_desktop_live.pbtxt new file mode 100644 index 000000000..9e6fdad06 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_detection_desktop_live.pbtxt @@ -0,0 +1,38 @@ +# MediaPipe graph that performs hand detection on desktop with TensorFlow Lite +# on CPU. +# Used in the example in +# mediapipie/examples/desktop/hand_tracking:hand_detection_cpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Performs hand detection model on the input frames. See +# hand_detection_cpu.pbtxt for the detail of the sub-graph. +node { + calculator: "HandDetectionSubgraph" + input_stream: "input_video" + output_stream: "DETECTIONS:output_detections" +} + +# Converts the detections to drawing primitives for annotation overlay. 
+node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the original image coming into +# the graph. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} diff --git a/mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt b/mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt new file mode 100644 index 000000000..29ad822a8 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt @@ -0,0 +1,126 @@ +# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite +# on CPU. +# Used in the example in +# mediapipie/examples/desktop/hand_tracking:hand_tracking_tflite. + +# max_queue_size limits the number of packets enqueued on any input stream +# by throttling inputs to the graph. This makes the graph only process one +# frame per time. +max_queue_size: 1 + +# Decodes an input video file into images and a video header. +node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_video_path" + output_stream: "VIDEO:input_video" + output_stream: "VIDEO_PRESTREAM:input_video_header" +} + +# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon +# the arrival of the next input image sends out the cached decision with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand-presence decision. Note that upon the arrival +# of the very first input image, an empty packet is sent out to jump start the +# feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:input_video" + input_stream: "LOOP:hand_presence" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_presence" +} + +# Drops the incoming image if HandLandmarkSubgraph was able to identify hand +# presence in the previous image. Otherwise, passes the incoming image through +# to trigger a new round of hand detection in HandDetectionSubgraph. +node { + calculator: "GateCalculator" + input_stream: "input_video" + input_stream: "DISALLOW:prev_hand_presence" + output_stream: "hand_detection_input_video" + + node_options: { + [type.googleapis.com/mediapipe.GateCalculatorOptions] { + empty_packets_as_allow: true + } + } +} + +# Subgraph that detections hands (see hand_detection_cpu.pbtxt). +node { + calculator: "HandDetectionSubgraph" + input_stream: "hand_detection_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +# Subgraph that localizes hand landmarks (see hand_landmark_cpu.pbtxt). +node { + calculator: "HandLandmarkSubgraph" + input_stream: "IMAGE:input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "LANDMARKS:hand_landmarks" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + output_stream: "PRESENCE:hand_presence" +} + +# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the +# arrival of the next input image sends out the cached rectangle with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand rectangle. 
Note that upon the arrival of the +# very first input image, an empty packet is sent out to jump start the +# feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:input_video" + input_stream: "LOOP:hand_rect_from_landmarks" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks" +} + +# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that +# generated by HandLandmarkSubgraph into a single output stream by selecting +# between one of the two streams. The former is selected if the incoming packet +# is not empty, i.e., hand detection is performed on the current image by +# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand +# presence in the previous image). Otherwise, the latter is selected, which is +# never empty because HandLandmarkSubgraphs processes all images (that went +# through FlowLimiterCaculator). +node { + calculator: "MergeCalculator" + input_stream: "hand_rect_from_palm_detections" + input_stream: "prev_hand_rect_from_landmarks" + output_stream: "hand_rect" +} + +# Subgraph that renders annotations and overlays them on top of the input +# images (see renderer_cpu.pbtxt). +node { + calculator: "RendererSubgraph" + input_stream: "IMAGE:input_video" + input_stream: "LANDMARKS:hand_landmarks" + input_stream: "NORM_RECT:hand_rect" + input_stream: "DETECTIONS:palm_detections" + output_stream: "IMAGE:output_video" +} + +# Encodes the annotated images into a video file, adopting properties specified +# in the input video header, e.g., video framerate. +node { + calculator: "OpenCvVideoEncoderCalculator" + input_stream: "VIDEO:output_video" + input_stream: "VIDEO_PRESTREAM:input_video_header" + input_side_packet: "OUTPUT_FILE_PATH:output_video_path" + node_options: { + [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { + codec: "avc1" + video_format: "mp4" + } + } +} diff --git a/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt b/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt new file mode 100644 index 000000000..3aefbf761 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt @@ -0,0 +1,103 @@ +# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite +# on CPU. +# Used in the example in +# mediapipie/examples/desktop/hand_tracking:hand_tracking_cpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon +# the arrival of the next input image sends out the cached decision with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand-presence decision. Note that upon the arrival +# of the very first input image, an empty packet is sent out to jump start the +# feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:input_video" + input_stream: "LOOP:hand_presence" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_presence" +} + +# Drops the incoming image if HandLandmarkSubgraph was able to identify hand +# presence in the previous image. Otherwise, passes the incoming image through +# to trigger a new round of hand detection in HandDetectionSubgraph. 
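For the live variants above, the caller pushes camera frames in and observes annotated frames coming out. A minimal, hypothetical sketch follows; RunLiveGraph, the 640x480 frame size, and the fixed loop count are placeholders, and the graph is assumed to already be Initialize()d from the pbtxt above.

#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"

::mediapipe::Status RunLiveGraph(mediapipe::CalculatorGraph* graph) {
  // Receive annotated frames as they are produced.
  MP_RETURN_IF_ERROR(graph->ObserveOutputStream(
      "output_video", [](const mediapipe::Packet& packet) {
        const auto& output = packet.Get<mediapipe::ImageFrame>();
        (void)output;  // Display or encode the frame here.
        return ::mediapipe::OkStatus();
      }));
  MP_RETURN_IF_ERROR(graph->StartRun({}));
  for (int i = 0; i < 100; ++i) {  // Placeholder: loop while frames arrive.
    auto frame = absl::make_unique<mediapipe::ImageFrame>(
        mediapipe::ImageFormat::SRGB, /*width=*/640, /*height=*/480);
    // Fill `frame` from the camera, then stamp it with a monotonically
    // increasing timestamp (microseconds).
    MP_RETURN_IF_ERROR(graph->AddPacketToInputStream(
        "input_video", mediapipe::Adopt(frame.release())
                           .At(mediapipe::Timestamp(i * 33333))));
  }
  MP_RETURN_IF_ERROR(graph->CloseInputStream("input_video"));
  return graph->WaitUntilDone();
}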
+node { + calculator: "GateCalculator" + input_stream: "input_video" + input_stream: "DISALLOW:prev_hand_presence" + output_stream: "hand_detection_input_video" + + node_options: { + [type.googleapis.com/mediapipe.GateCalculatorOptions] { + empty_packets_as_allow: true + } + } +} + +# Subgraph that detections hands (see hand_detection_cpu.pbtxt). +node { + calculator: "HandDetectionSubgraph" + input_stream: "hand_detection_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +# Subgraph that localizes hand landmarks (see hand_landmark_cpu.pbtxt). +node { + calculator: "HandLandmarkSubgraph" + input_stream: "IMAGE:input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "LANDMARKS:hand_landmarks" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + output_stream: "PRESENCE:hand_presence" +} + +# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the +# arrival of the next input image sends out the cached rectangle with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand rectangle. Note that upon the arrival of the +# very first input image, an empty packet is sent out to jump start the +# feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:input_video" + input_stream: "LOOP:hand_rect_from_landmarks" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks" +} + +# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that +# generated by HandLandmarkSubgraph into a single output stream by selecting +# between one of the two streams. The former is selected if the incoming packet +# is not empty, i.e., hand detection is performed on the current image by +# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand +# presence in the previous image). Otherwise, the latter is selected, which is +# never empty because HandLandmarkSubgraphs processes all images (that went +# through FlowLimiterCaculator). +node { + calculator: "MergeCalculator" + input_stream: "hand_rect_from_palm_detections" + input_stream: "prev_hand_rect_from_landmarks" + output_stream: "hand_rect" +} + +# Subgraph that renders annotations and overlays them on top of the input +# images (see renderer_cpu.pbtxt). +node { + calculator: "RendererSubgraph" + input_stream: "IMAGE:input_video" + input_stream: "LANDMARKS:hand_landmarks" + input_stream: "NORM_RECT:hand_rect" + input_stream: "DETECTIONS:palm_detections" + output_stream: "IMAGE:output_video" +} + diff --git a/mediapipe/graphs/hand_tracking/subgraphs/BUILD b/mediapipe/graphs/hand_tracking/subgraphs/BUILD new file mode 100644 index 000000000..93a0d1048 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/subgraphs/BUILD @@ -0,0 +1,132 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load( + "//mediapipe/framework/tool:mediapipe_graph.bzl", + "mediapipe_simple_subgraph", +) + +mediapipe_simple_subgraph( + name = "hand_detection_cpu", + graph = "hand_detection_cpu.pbtxt", + register_as = "HandDetectionSubgraph", + deps = [ + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_render_data_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "hand_landmark_cpu", + graph = "hand_landmark_cpu.pbtxt", + register_as = "HandLandmarkSubgraph", + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/image:image_cropping_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", + "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/calculators/util:landmarks_to_detection_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:thresholding_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "renderer_cpu", + graph = "renderer_cpu.pbtxt", + register_as = "RendererSubgraph", + deps = [ + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:detections_to_render_data_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:rect_to_render_data_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "hand_detection_gpu", + graph = "hand_detection_gpu.pbtxt", + register_as = "HandDetectionSubgraph", + deps = [ + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + 
"//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "hand_landmark_gpu", + graph = "hand_landmark_gpu.pbtxt", + register_as = "HandLandmarkSubgraph", + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/image:image_cropping_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", + "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/calculators/util:landmarks_to_detection_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:thresholding_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "renderer_gpu", + graph = "renderer_gpu.pbtxt", + register_as = "RendererSubgraph", + deps = [ + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:detections_to_render_data_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:rect_to_render_data_calculator", + ], +) diff --git a/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_cpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_cpu.pbtxt new file mode 100644 index 000000000..65c7d162f --- /dev/null +++ b/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_cpu.pbtxt @@ -0,0 +1,193 @@ +# MediaPipe hand detection subgraph. + +type: "HandDetectionSubgraph" + +input_stream: "input_video" +output_stream: "DETECTIONS:palm_detections" +output_stream: "NORM_RECT:hand_rect_from_palm_detections" + +# Transforms the input image on CPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:input_video" + output_stream: "IMAGE:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "op_resolver" +} + +# Converts the transformed input image on CPU into an image tensor as a +# TfLiteTensor. The zero_center option is set to true to normalize the +# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video" + output_stream: "TENSORS:image_tensor" +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. 
+node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/palm_detection.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 5 + min_scale: 0.1171875 + max_scale: 0.75 + input_size_height: 256 + input_size_width: 256 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 32 + strides: 32 + strides: 32 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 2944 + num_coords: 18 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 7 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + + x_scale: 256.0 + y_scale: 256.0 + h_scale: 256.0 + w_scale: 256.0 + min_score_thresh: 0.5 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + min_score_threshold: 0.5 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/palm_detection_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:palm_detections" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:input_video" + output_stream: "SIZE:image_size" +} + +# Converts results of palm detection into a rectangle (normalized by image size) +# that encloses the palm and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. 
+node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:palm_detections" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:palm_rect" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 2 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + output_zero_rect_for_empty_detections: true + } + } +} + +# Expands and shifts the rectangle that contains the palm so that it's likely +# to cover the entire hand. +node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:palm_rect" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_from_palm_detections" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 2.6 + scale_y: 2.6 + shift_y: -0.5 + square_long: true + } + } +} diff --git a/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt similarity index 96% rename from mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt rename to mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt index 848bacb9f..833286066 100644 --- a/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt +++ b/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt @@ -49,11 +49,11 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS:detection_tensors" + output_stream: "TENSORS_GPU:detection_tensors" input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "palm_detection.tflite" + model_path: "mediapipe/models/palm_detection.tflite" use_gpu: true } } @@ -89,7 +89,7 @@ node { # detections. Each detection describes a detected object. node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS:detection_tensors" + input_stream: "TENSORS_GPU:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { @@ -137,7 +137,7 @@ node { output_stream: "labeled_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "palm_detection_labelmap.txt" + label_map_path: "mediapipe/models/palm_detection_labelmap.txt" } } } diff --git a/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt new file mode 100644 index 000000000..ad52a5716 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt @@ -0,0 +1,185 @@ +# MediaPipe hand landmark localization subgraph. + +type: "HandLandmarkSubgraph" + +input_stream: "IMAGE:input_video" +input_stream: "NORM_RECT:hand_rect" +output_stream: "LANDMARKS:hand_landmarks" +output_stream: "NORM_RECT:hand_rect_for_next_frame" +output_stream: "PRESENCE:hand_presence" + +# Crops the rectangle that contains a hand from the input image. +node { + calculator: "ImageCroppingCalculator" + input_stream: "IMAGE:input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "IMAGE:hand_image" +} + +# Transforms the input image on CPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. 
+node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:hand_image" + output_stream: "IMAGE:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored in +# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the +# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically +# option is set to true to account for the descrepancy between the +# representation of the input image (origin at the bottom-left corner, the +# OpenGL convention) and what the model used in this graph is expecting (origin +# at the top-left corner). +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video" + output_stream: "TENSORS:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: false + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:output_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/hand_landmark.tflite" + } + } +} + +# Splits a vector of TFLite tensors to multiple vectors according to the ranges +# specified in option. +node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "output_tensors" + output_stream: "landmark_tensors" + output_stream: "hand_flag_tensor" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + } + } +} + +# Converts the hand-flag tensor into a float that represents the confidence +# score of hand presence. +node { + calculator: "TfLiteTensorsToFloatsCalculator" + input_stream: "TENSORS:hand_flag_tensor" + output_stream: "FLOAT:hand_presence_score" +} + +# Applies a threshold to the confidence score to determine whether a hand is +# present. +node { + calculator: "ThresholdingCalculator" + input_stream: "FLOAT:hand_presence_score" + output_stream: "FLAG:hand_presence" + node_options: { + [type.googleapis.com/mediapipe.ThresholdingCalculatorOptions] { + threshold: 0.1 + } + } +} + +# Decodes the landmark tensors into a vector of lanmarks, where the landmark +# coordinates are normalized by the size of the input image to the model. +node { + calculator: "TfLiteTensorsToLandmarksCalculator" + input_stream: "TENSORS:landmark_tensors" + output_stream: "NORM_LANDMARKS:landmarks" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToLandmarksCalculatorOptions] { + num_landmarks: 21 + input_image_width: 256 + input_image_height: 256 + } + } +} + +# Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed hand +# image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (hand +# image before image transformation). 
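Undoing the letterboxing is a simple rescale of the normalized coordinates. A sketch of the adjustment the comment describes follows; RemoveLetterbox and NormalizedPoint are illustrative names, and `padding` is the normalized {left, top, right, bottom} padding produced by the FIT transform.

#include <array>

struct NormalizedPoint {
  float x;
  float y;
};

// Maps a point normalized w.r.t. the letterboxed (padded) image back to
// coordinates on the unpadded crop.
NormalizedPoint RemoveLetterbox(const NormalizedPoint& pt,
                                const std::array<float, 4>& padding) {
  const float left = padding[0], top = padding[1];
  const float right = padding[2], bottom = padding[3];
  return {(pt.x - left) / (1.f - left - right),
          (pt.y - top) / (1.f - top - bottom)};
}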
+node { + calculator: "LandmarkLetterboxRemovalCalculator" + input_stream: "LANDMARKS:landmarks" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "LANDMARKS:scaled_landmarks" +} + +# Projects the landmarks from the cropped hand image to the corresponding +# locations on the full image before cropping (input to the graph). +node { + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:scaled_landmarks" + input_stream: "NORM_RECT:hand_rect" + output_stream: "NORM_LANDMARKS:hand_landmarks" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:input_video" + output_stream: "SIZE:image_size" +} + +# Converts hand landmarks to a detection that tightly encloses all landmarks. +node { + calculator: "LandmarksToDetectionCalculator" + input_stream: "NORM_LANDMARKS:hand_landmarks" + output_stream: "DETECTION:hand_detection" +} + +# Converts the hand detection into a rectangle (normalized by image size) +# that encloses the hand and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. +node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:hand_detection" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 9 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + } + } +} + +# Expands the hand rectangle so that in the next video frame it's likely to +# still contain the hand even with some motion. +node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:hand_rect_from_landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_for_next_frame" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 1.6 + scale_y: 1.6 + square_long: true + } + } +} diff --git a/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt similarity index 96% rename from mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt rename to mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt index 467abd4c5..283ce459c 100644 --- a/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt +++ b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt @@ -39,6 +39,11 @@ node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:transformed_hand_image" output_stream: "TENSORS_GPU:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { + zero_center: false + } + } } # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a @@ -50,7 +55,7 @@ node { output_stream: "TENSORS:output_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "hand_landmark.tflite" + model_path: "mediapipe/models/hand_landmark.tflite" use_gpu: true } } diff --git a/mediapipe/graphs/hand_tracking/subgraphs/renderer_cpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/renderer_cpu.pbtxt new file mode 100644 index 000000000..c3033155d --- /dev/null +++ b/mediapipe/graphs/hand_tracking/subgraphs/renderer_cpu.pbtxt @@ -0,0 +1,102 @@ +# MediaPipe hand tracking rendering subgraph. 
+ +type: "RendererSubgraph" + +input_stream: "IMAGE:input_image" +input_stream: "DETECTIONS:detections" +input_stream: "LANDMARKS:landmarks" +input_stream: "NORM_RECT:rect" +output_stream: "IMAGE:output_image" + +# Converts detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RENDER_DATA:detection_render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Converts landmarks to drawing primitives for annotation overlay. +node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + output_stream: "RENDER_DATA:landmark_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_connections: 0 + landmark_connections: 1 + landmark_connections: 1 + landmark_connections: 2 + landmark_connections: 2 + landmark_connections: 3 + landmark_connections: 3 + landmark_connections: 4 + landmark_connections: 0 + landmark_connections: 5 + landmark_connections: 5 + landmark_connections: 6 + landmark_connections: 6 + landmark_connections: 7 + landmark_connections: 7 + landmark_connections: 8 + landmark_connections: 5 + landmark_connections: 9 + landmark_connections: 9 + landmark_connections: 10 + landmark_connections: 10 + landmark_connections: 11 + landmark_connections: 11 + landmark_connections: 12 + landmark_connections: 9 + landmark_connections: 13 + landmark_connections: 13 + landmark_connections: 14 + landmark_connections: 14 + landmark_connections: 15 + landmark_connections: 15 + landmark_connections: 16 + landmark_connections: 13 + landmark_connections: 17 + landmark_connections: 0 + landmark_connections: 17 + landmark_connections: 17 + landmark_connections: 18 + landmark_connections: 18 + landmark_connections: 19 + landmark_connections: 19 + landmark_connections: 20 + landmark_color { r: 255 g: 0 b: 0 } + connection_color { r: 0 g: 255 b: 0 } + thickness: 4.0 + } + } +} + +# Converts normalized rects to drawing primitives for annotation overlay. +node { + calculator: "RectToRenderDataCalculator" + input_stream: "NORM_RECT:rect" + output_stream: "RENDER_DATA:rect_render_data" + node_options: { + [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { + filled: false + color { r: 255 g: 0 b: 0 } + thickness: 4.0 + } + } +} + +# Draws annotations and overlays them on top of the input images. 
+node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:input_image" + input_stream: "detection_render_data" + input_stream: "landmark_render_data" + input_stream: "rect_render_data" + output_stream: "OUTPUT_FRAME:output_image" +} diff --git a/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/renderer_gpu.pbtxt similarity index 100% rename from mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt rename to mediapipe/graphs/hand_tracking/subgraphs/renderer_gpu.pbtxt diff --git a/mediapipe/graphs/object_detection/BUILD b/mediapipe/graphs/object_detection/BUILD index cd1d1b6be..36c0181a9 100644 --- a/mediapipe/graphs/object_detection/BUILD +++ b/mediapipe/graphs/object_detection/BUILD @@ -56,6 +56,10 @@ cc_library( cc_library( name = "desktop_tflite_calculators", deps = [ + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/tflite:tflite_converter_calculator", diff --git a/mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt b/mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt new file mode 100644 index 000000000..899785a1c --- /dev/null +++ b/mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt @@ -0,0 +1,174 @@ +# MediaPipe graph that performs object detection with TensorFlow Lite on CPU. +# Used in the examples in +# mediapipie/examples/desktop/object_detection:object_detection_cpu. + +# Images on CPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish +# generating the corresponding detections before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy +# processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on CPU to a 320x320 image. To scale the image, by +# default it uses the STRETCH scale mode that maps the entire input image to the +# entire transformed image. As a result, image aspect ratio may be changed and +# objects in the image may be deformed (stretched or squeezed), but the object +# detection model used in this graph is agnostic to that deformation. 
+node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:throttled_input_video" + output_stream: "IMAGE:transformed_input_video" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 320 + output_height: 320 + } + } +} + +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE:transformed_input_video" + output_stream: "TENSORS:image_tensor" +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "mediapipe/models/ssdlite_object_detection.tflite" + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + input_size_height: 320 + input_size_width: 320 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 16 + strides: 32 + strides: 64 + strides: 128 + strides: 256 + strides: 512 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + reduce_boxes_in_lowest_layer: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 91 + num_boxes: 2034 + num_coords: 4 + ignore_classes: 0 + sigmoid_score: true + apply_exponential_on_box_size: true + x_scale: 10.0 + y_scale: 10.0 + h_scale: 5.0 + w_scale: 5.0 + min_score_thresh: 0.6 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.4 + max_num_detections: 3 + overlap_type: INTERSECTION_OVER_UNION + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "output_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "mediapipe/models/ssdlite_object_detection_labelmap.txt" + } + } +} + +# Converts the detections to drawing primitives for annotation overlay. 
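The TfLiteTensorsToDetectionsCalculator options above (x_scale, y_scale, w_scale, h_scale, apply_exponential_on_box_size) control how each raw box regression is combined with its SSD anchor. A simplified sketch of that decoding follows; DecodeBox, Anchor, and DecodedBox are illustrative types, not the calculator's actual code.

#include <cmath>

struct Anchor { float x_center, y_center, w, h; };
struct DecodedBox { float x_center, y_center, w, h; };

// `raw` holds the model's four box values, here taken in (y, x, h, w) order
// (reverse_output_order swaps this to (x, y, w, h)).
DecodedBox DecodeBox(const float raw[4], const Anchor& anchor, float x_scale,
                     float y_scale, float w_scale, float h_scale,
                     bool apply_exponential_on_box_size) {
  DecodedBox box;
  box.y_center = raw[0] / y_scale * anchor.h + anchor.y_center;
  box.x_center = raw[1] / x_scale * anchor.w + anchor.x_center;
  if (apply_exponential_on_box_size) {
    box.h = std::exp(raw[2] / h_scale) * anchor.h;
    box.w = std::exp(raw[3] / w_scale) * anchor.w;
  } else {
    box.h = raw[2] / h_scale * anchor.h;
    box.w = raw[3] / w_scale * anchor.w;
  }
  return box;
}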
+node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:output_detections" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME:throttled_input_video" + input_stream: "render_data" + output_stream: "OUTPUT_FRAME:output_video" +} diff --git a/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt b/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt index 4eb527a3c..3e0e4e6d3 100644 --- a/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt +++ b/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt @@ -75,7 +75,7 @@ node { output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "ssdlite_object_detection.tflite" + model_path: "mediapipe/models/ssdlite_object_detection.tflite" } } } @@ -158,7 +158,7 @@ node { output_stream: "output_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "ssdlite_object_detection_labelmap.txt" + label_map_path: "mediapipe/models/ssdlite_object_detection_labelmap.txt" } } } diff --git a/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt b/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt index 44bf61057..dfed16696 100644 --- a/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt +++ b/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt @@ -62,10 +62,10 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS:detection_tensors" + output_stream: "TENSORS_GPU:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "ssdlite_object_detection.tflite" + model_path: "mediapipe/models/ssdlite_object_detection.tflite" } } } @@ -105,7 +105,7 @@ node { # detections. Each detection describes a detected object. 
node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS:detection_tensors" + input_stream: "TENSORS_GPU:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { @@ -148,7 +148,7 @@ node { output_stream: "output_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "ssdlite_object_detection_labelmap.txt" + label_map_path: "mediapipe/models/ssdlite_object_detection_labelmap.txt" } } } diff --git a/mediapipe/graphs/youtube8m/BUILD b/mediapipe/graphs/youtube8m/BUILD index 4bfb0d46d..be0fff44c 100644 --- a/mediapipe/graphs/youtube8m/BUILD +++ b/mediapipe/graphs/youtube8m/BUILD @@ -17,7 +17,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) cc_library( - name = "yt8m_calculators_deps", + name = "yt8m_feature_extraction_calculators", deps = [ "//mediapipe/calculators/audio:audio_decoder_calculator", "//mediapipe/calculators/audio:basic_time_series_calculators", diff --git a/mediapipe/graphs/youtube8m/feature_extraction.pbtxt b/mediapipe/graphs/youtube8m/feature_extraction.pbtxt index 42fd5988e..89d1053de 100644 --- a/mediapipe/graphs/youtube8m/feature_extraction.pbtxt +++ b/mediapipe/graphs/youtube8m/feature_extraction.pbtxt @@ -16,12 +16,16 @@ node { input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" output_side_packet: "DATA_PATH:input_file" output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" + output_side_packet: "AUDIO_DECODER_OPTIONS:audio_decoder_options" node_options: { [type.googleapis.com/mediapipe.UnpackMediaSequenceCalculatorOptions]: { base_packet_resampler_options { frame_rate: 1.0 base_timestamp: 0 } + base_audio_decoder_options { + audio_stream { stream_index: 0 } + } } } } @@ -121,13 +125,9 @@ node { node { calculator: "AudioDecoderCalculator" input_side_packet: "INPUT_FILE_PATH:input_file" + input_side_packet: "OPTIONS:audio_decoder_options" output_stream: "AUDIO:audio" output_stream: "AUDIO_HEADER:audio_header" - node_options: { - [type.googleapis.com/mediapipe.AudioDecoderOptions]: { - audio_stream { stream_index: 0 } - } - } } node { diff --git a/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java b/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java index ba506524f..c63f0495a 100644 --- a/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java +++ b/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java @@ -185,16 +185,10 @@ public class FrameProcessor implements TextureFrameProcessor { public void close() { if (started.get()) { try { - mediapipeGraph.closeAllInputStreams(); - - // TODO Add a way to signal a source calculator to stop. - // Required for a graph containing a source calculator to shut down properly. - mediapipeGraph.cancelGraph(); - + mediapipeGraph.closeAllPacketSources(); mediapipeGraph.waitUntilGraphDone(); } catch (MediaPipeException e) { - // TODO: cancelGraph will cause an exception to be raised in waitUntilGraphDone. - // We should not cancel the graph here! Also, we should handle exceptions better. 
+ Log.e(TAG, "Mediapipe error: ", e); } try { mediapipeGraph.tearDown(); diff --git a/mediapipe/java/com/google/mediapipe/framework/Graph.java b/mediapipe/java/com/google/mediapipe/framework/Graph.java index 0feabf386..9065d4e50 100644 --- a/mediapipe/java/com/google/mediapipe/framework/Graph.java +++ b/mediapipe/java/com/google/mediapipe/framework/Graph.java @@ -16,7 +16,6 @@ package com.google.mediapipe.framework; import com.google.common.base.Preconditions; import com.google.common.flogger.FluentLogger; -import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; import com.google.mediapipe.proto.GraphTemplateProto.CalculatorGraphTemplate; import com.google.protobuf.InvalidProtocolBufferException; @@ -119,7 +118,7 @@ public class Graph { } /** Specifies options such as template arguments for the graph. */ - public synchronized void setGraphOptions(CalculatorOptions options) { + public synchronized void setGraphOptions(CalculatorGraphConfig.Node options) { nativeSetGraphOptions(nativeGraphHandle, options.toByteArray()); } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index a02cd2e33..182226cbb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -73,9 +73,6 @@ cc_library( ], "//mediapipe/gpu:disable_gpu": [], }), - copts = [ - "-DDISABLE_GOOGLE_GLOBAL_USING_DECLARATIONS", # b/33667913 - ], linkopts = select({ "//conditions:default": [], "//mediapipe:android": [ @@ -123,7 +120,9 @@ cc_library( "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", ], - "//mediapipe/gpu:disable_gpu": [], + "//mediapipe/gpu:disable_gpu": [ + "//mediapipe/gpu:gpu_shared_data_internal", + ], }), alwayslink = 1, ) diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h index 39bd91446..c6f64b6fe 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h @@ -191,7 +191,7 @@ class Graph { // CalculatorGraphTemplates for the calculator graph and subgraphs. std::vector graph_templates_; // Options such as template arguments for the top-level calculator graph. - CalculatorOptions graph_options_; + Subgraph::SubgraphOptions graph_options_; // The CalculatorGraphConfig::type of the top-level calculator graph. std::string graph_type_ = ""; diff --git a/mediapipe/objc/MPPGraph.h b/mediapipe/objc/MPPGraph.h index e0371c081..6823aad18 100644 --- a/mediapipe/objc/MPPGraph.h +++ b/mediapipe/objc/MPPGraph.h @@ -116,6 +116,13 @@ typedef NS_ENUM(int, MPPPacketType) { /// @param name The name of the input side packet. - (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name; +/// Sets a service packet. If it was already set, it is overwritten. +/// Must be called before the graph is started. +/// @param packet The packet to be associated with the service. +/// @param service. +- (void)setServicePacket:(mediapipe::Packet&)packet + forService:(const mediapipe::GraphServiceBase&)service; + /// Adds input side packets from a map. Any inputs that were already set are /// left unchanged. /// Must be called before the graph is started. 
diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 636f3edbf..8c9da2011 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -22,6 +22,7 @@ #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/graph_service.h" #include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" @@ -38,6 +39,8 @@ std::map _inputSidePackets; /// Packet headers that will be added to the graph when it is started. std::map _streamHeaders; + /// Service packets to be added to the graph when it is started. + std::map _servicePackets; /// Number of frames currently being processed by the graph. std::atomic _framesInFlight; @@ -199,6 +202,13 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, _inputSidePackets[name] = packet; } +- (void)setServicePacket:(mediapipe::Packet&)packet + forService:(const mediapipe::GraphServiceBase&)service { + _GTMDevAssert(!_started, @"%@ must be called before the graph is started", + NSStringFromSelector(_cmd)); + _servicePackets[&service] = std::move(packet); +} + - (void)addSidePackets:(const std::map&)extraSidePackets { _GTMDevAssert(!_started, @"%@ must be called before the graph is started", NSStringFromSelector(_cmd)); @@ -206,18 +216,33 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, } - (BOOL)startWithError:(NSError**)error { + ::mediapipe::Status status = [self performStart]; + if (!status.ok()) { + if (error) { + *error = [NSError gus_errorWithStatus:status]; + } + return NO; + } + _started = YES; + return YES; +} + +- (::mediapipe::Status)performStart { ::mediapipe::Status status = _graph->Initialize(_config); - if (status.ok()) { - status = _graph->StartRun(_inputSidePackets, _streamHeaders); - if (status.ok()) { - _started = YES; - return YES; + if (!status.ok()) { + return status; + } + for (const auto& service_packet : _servicePackets) { + status = _graph->SetServicePacket(*service_packet.first, service_packet.second); + if (!status.ok()) { + return status; } } - if (error) { - *error = [NSError gus_errorWithStatus:status]; + status = _graph->StartRun(_inputSidePackets, _streamHeaders); + if (!status.ok()) { + return status; } - return NO; + return status; } - (void)cancel { diff --git a/mediapipe/objc/MPPGraphTestBase.h b/mediapipe/objc/MPPGraphTestBase.h index c9b67c264..7c457fbbe 100644 --- a/mediapipe/objc/MPPGraphTestBase.h +++ b/mediapipe/objc/MPPGraphTestBase.h @@ -61,6 +61,9 @@ /// Loads an image from the test bundle. - (UIImage*)testImageNamed:(NSString*)name extension:(NSString*)extension; +/// Returns a URL for a file.extension in the test bundle. +- (NSURL*)URLForTestFile:(NSString*)file extension:(NSString*)extension; + /// Loads an image from the test bundle in subpath. 
- (UIImage*)testImageNamed:(NSString*)name extension:(NSString*)extension diff --git a/mediapipe/objc/MPPGraphTestBase.mm b/mediapipe/objc/MPPGraphTestBase.mm index cfb757369..46fe42755 100644 --- a/mediapipe/objc/MPPGraphTestBase.mm +++ b/mediapipe/objc/MPPGraphTestBase.mm @@ -43,9 +43,13 @@ static void EnsureOutputDirFor(NSString *outputFile) { @implementation MPPGraphTestBase -- (NSData*)testDataNamed:(NSString*)name extension:(NSString*)extension { +- (NSURL*)URLForTestFile:(NSString*)file extension:(NSString*)extension { NSBundle* testBundle = [NSBundle bundleForClass:[self class]]; - NSURL* resourceURL = [testBundle URLForResource:name withExtension:extension]; + return [testBundle URLForResource:file withExtension:extension]; +} + +- (NSData*)testDataNamed:(NSString*)name extension:(NSString*)extension { + NSURL* resourceURL = [self URLForTestFile:name extension:extension]; XCTAssertNotNil(resourceURL, @"Unable to find data with name: %@. Did you add it to your resources?", name); NSError* error; diff --git a/mediapipe/util/android/file/base/helpers.cc b/mediapipe/util/android/file/base/helpers.cc index bfa144d9a..930f916fa 100644 --- a/mediapipe/util/android/file/base/helpers.cc +++ b/mediapipe/util/android/file/base/helpers.cc @@ -47,8 +47,9 @@ class FdCloser { const file::Options& /*options*/) { int fd = open(std::string(file_name).c_str(), O_RDONLY); if (fd < 0) { - return ::mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to open file"); + return ::mediapipe::Status( + mediapipe::StatusCode::kUnknown, + "Failed to open file: " + std::string(file_name)); } FdCloser closer(fd); @@ -92,8 +93,9 @@ class FdCloser { int fd = open(std::string(file_name).c_str(), O_WRONLY | O_CREAT | O_TRUNC, mode); if (fd < 0) { - return ::mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to open file"); + return ::mediapipe::Status( + mediapipe::StatusCode::kUnknown, + "Failed to open file: " + std::string(file_name)); } int bytes_written = 0; diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 85ca2e6b7..54ba7bb17 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -357,7 +357,8 @@ void AnnotationRenderer::DrawFilledOval(const RenderAnnotation& annotation) { bottom = static_cast(enclosing_rectangle.bottom()); } cv::Point center((left + right) / 2, (top + bottom) / 2); - cv::Size size((right - left) / 2, (bottom - top) / 2); + cv::Size size(std::max(0, (right - left) / 2), + std::max(0, (bottom - top) / 2)); const cv::Scalar color = MediapipeColorToOpenCVColor(annotation.color()); cv::ellipse(mat_image_, center, size, 0, 0, 360, color, -1); } diff --git a/mediapipe/util/resource_util_android.cc b/mediapipe/util/resource_util_android.cc index 00d3b83f6..643739566 100644 --- a/mediapipe/util/resource_util_android.cc +++ b/mediapipe/util/resource_util_android.cc @@ -21,13 +21,38 @@ namespace mediapipe { +namespace { +::mediapipe::StatusOr PathToResourceAsFileInternal( + const std::string& path) { + return Singleton::get()->CachedFileFromAsset(path); +} +} // namespace + ::mediapipe::StatusOr PathToResourceAsFile( const std::string& path) { + // Return full path. if (absl::StartsWith(path, "/")) { return path; } - return Singleton::get()->CachedFileFromAsset(path); + // Try to load a relative path or a base filename as is. 
+ { + auto status_or_path = PathToResourceAsFileInternal(path); + if (status_or_path.ok()) { + LOG(INFO) << "Successfully loaded: " << path; + return status_or_path; + } + } + + // If that fails, assume it was a relative path, and try just the base name. + { + const size_t last_slash_idx = path.find_last_of("\\/"); + CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. + auto base_name = path.substr(last_slash_idx + 1); + auto status_or_path = PathToResourceAsFileInternal(base_name); + if (status_or_path.ok()) LOG(INFO) << "Successfully loaded: " << base_name; + return status_or_path; + } } ::mediapipe::Status GetResourceContents(const std::string& path, diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index bcd90fe1a..9b7677679 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -23,25 +23,48 @@ namespace mediapipe { -::mediapipe::StatusOr PathToResourceAsFile( +namespace { +::mediapipe::StatusOr PathToResourceAsFileInternal( const std::string& path) { - if (absl::StartsWith(path, "/")) { - return path; - } - NSString* ns_path = [NSString stringWithUTF8String:path.c_str()]; Class mediapipeGraphClass = NSClassFromString(@"MPPGraph"); NSString* resource_dir = [[NSBundle bundleForClass:mediapipeGraphClass] resourcePath]; NSString* resolved_ns_path = [resource_dir stringByAppendingPathComponent:ns_path]; - std::string resolved_path = [resolved_ns_path UTF8String]; RET_CHECK([[NSFileManager defaultManager] fileExistsAtPath:resolved_ns_path]) << "cannot find file: " << resolved_path; - return resolved_path; } +} // namespace + +::mediapipe::StatusOr PathToResourceAsFile( + const std::string& path) { + // Return full path. + if (absl::StartsWith(path, "/")) { + return path; + } + + // Try to load a relative path or a base filename as is. + { + auto status_or_path = PathToResourceAsFileInternal(path); + if (status_or_path.ok()) { + LOG(INFO) << "Successfully loaded: " << path; + return status_or_path; + } + } + + // If that fails, assume it was a relative path, and try just the base name. + { + const size_t last_slash_idx = path.find_last_of("\\/"); + CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. + auto base_name = path.substr(last_slash_idx + 1); + auto status_or_path = PathToResourceAsFileInternal(base_name); + if (status_or_path.ok()) LOG(INFO) << "Successfully loaded: " << base_name; + return status_or_path; + } +} ::mediapipe::Status GetResourceContents(const std::string& path, std::string* output) { diff --git a/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff b/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff new file mode 100644 index 000000000..776e6d671 --- /dev/null +++ b/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff @@ -0,0 +1,58 @@ +commit 9779e5ea6ef59562b030248947f787d1256132ae +Author: jqtang +Date: Wed Sep 18 11:43:48 2019 -0700 + + Add glog Android support for MediaPipe. 
+ +diff --git a/src/logging.cc b/src/logging.cc +index 0b5e6ee..be5a506 100644 +--- a/src/logging.cc ++++ b/src/logging.cc +@@ -67,6 +67,10 @@ + # include "stacktrace.h" + #endif + ++#ifdef __ANDROID__ ++#include ++#endif ++ + using std::string; + using std::vector; + using std::setw; +@@ -1279,6 +1283,23 @@ ostream& LogMessage::stream() { + return data_->stream_; + } + ++namespace { ++#if defined(__ANDROID__) ++int AndroidLogLevel(const int severity) { ++ switch (severity) { ++ case 3: ++ return ANDROID_LOG_FATAL; ++ case 2: ++ return ANDROID_LOG_ERROR; ++ case 1: ++ return ANDROID_LOG_WARN; ++ default: ++ return ANDROID_LOG_INFO; ++ } ++} ++#endif // defined(__ANDROID__) ++} // namespace ++ + // Flush buffered message, called by the destructor, or any other function + // that needs to synchronize the log. + void LogMessage::Flush() { +@@ -1313,6 +1334,12 @@ void LogMessage::Flush() { + } + LogDestination::WaitForSinks(data_); + ++#if defined(__ANDROID__) ++ const int level = AndroidLogLevel((int)data_->severity_); ++ const std::string text = std::string(data_->message_text_); ++ __android_log_write(level, "native", text.substr(0,data_->num_chars_to_log_).c_str()); ++#endif // !defined(__ANDROID__) ++ + if (append_newline) { + // Fix the ostrstream back how it was before we screwed with it. + // It's 99.44% certain that we don't need to worry about doing this. \ No newline at end of file
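The resource_util changes above are what let the graphs refer to models by repository-relative paths (e.g. mediapipe/models/palm_detection.tflite) on every platform: on desktop the relative path resolves directly, while on Android and iOS a failed lookup falls back to the base file name packaged with the app. A small usage sketch, with ResolvePalmDetectionModel as an illustrative name:

#include <string>

#include "mediapipe/util/resource_util.h"

std::string ResolvePalmDetectionModel() {
  auto path_or = mediapipe::PathToResourceAsFile(
      "mediapipe/models/palm_detection.tflite");
  // On Android/iOS this succeeds either with the relative path or, failing
  // that, with the base name "palm_detection.tflite".
  return path_or.ok() ? path_or.ValueOrDie() : std::string();
}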