From f645c597463b7be8b5da6dcb365bb34520d7c996 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Aug 2023 10:42:26 -0700 Subject: [PATCH] Move stream API image_size to third_party. PiperOrigin-RevId: 559475476 --- mediapipe/framework/api2/stream/BUILD | 46 ++++++++++++++- mediapipe/framework/api2/stream/image_size.h | 34 +++++++++++ .../framework/api2/stream/image_size_test.cc | 57 +++++++++++++++++++ .../framework/api2/stream/loopback_test.cc | 33 +++++------ 4 files changed, 153 insertions(+), 17 deletions(-) create mode 100644 mediapipe/framework/api2/stream/image_size.h create mode 100644 mediapipe/framework/api2/stream/image_size_test.cc diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index f9f371d2f..4444938ac 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -10,5 +10,49 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", ], - alwayslink = 1, +) + +cc_test( + name = "loopback_test", + srcs = ["loopback_test.cc"], + deps = [ + ":loopback", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) + +cc_library( + name = "image_size", + hdrs = ["image_size.h"], + deps = [ + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/gpu:gpu_buffer", + ], +) + +cc_test( + name = "image_size_test", + srcs = ["image_size_test.cc"], + deps = [ + ":image_size", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + "//mediapipe/gpu:gpu_buffer", + ], ) diff --git a/mediapipe/framework/api2/stream/image_size.h b/mediapipe/framework/api2/stream/image_size.h new file mode 100644 index 000000000..b726f07a9 --- /dev/null +++ b/mediapipe/framework/api2/stream/image_size.h @@ -0,0 +1,34 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_ + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gpu_buffer.h" + +namespace mediapipe::api2::builder { + +// Updates graph to calculate image size and returns corresponding stream. +// +// @image image represented as ImageFrame/Image/GpuBuffer. +// @graph graph to update. +template +Stream> GetImageSize( + Stream image, mediapipe::api2::builder::Graph& graph) { + auto& img_props_node = graph.AddNode("ImagePropertiesCalculator"); + if constexpr (std::is_same_v || + std::is_same_v) { + image.ConnectTo(img_props_node.In("IMAGE")); + } else if constexpr (std::is_same_v) { + image.ConnectTo(img_props_node.In("IMAGE_GPU")); + } else { + static_assert(dependent_false::value, "Type not supported."); + } + return img_props_node.Out("SIZE").Cast>(); +} + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_ diff --git a/mediapipe/framework/api2/stream/image_size_test.cc b/mediapipe/framework/api2/stream/image_size_test.cc new file mode 100644 index 000000000..3b080ba02 --- /dev/null +++ b/mediapipe/framework/api2/stream/image_size_test.cc @@ -0,0 +1,57 @@ +#include "mediapipe/framework/api2/stream/image_size.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/gpu/gpu_buffer.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(GetImageSize, VerifyConfig) { + Graph graph; + + Stream image_frame = graph.In("IMAGE_FRAME").Cast(); + image_frame.SetName("image_frame"); + Stream gpu_buffer = graph.In("GPU_BUFFER").Cast(); + gpu_buffer.SetName("gpu_buffer"); + Stream image = graph.In("IMAGE").Cast(); + image.SetName("image"); + + GetImageSize(image_frame, graph).SetName("image_frame_size"); + GetImageSize(gpu_buffer, graph).SetName("gpu_buffer_size"); + GetImageSize(image, graph).SetName("image_size"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:image_frame" + output_stream: "SIZE:image_frame_size" + } + node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:gpu_buffer" + output_stream: "SIZE:gpu_buffer_size" + } + node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:image" + output_stream: "SIZE:image_size" + } + input_stream: "GPU_BUFFER:gpu_buffer" + input_stream: "IMAGE:image" + input_stream: "IMAGE_FRAME:image_frame" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} +} // namespace +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/loopback_test.cc b/mediapipe/framework/api2/stream/loopback_test.cc index 8b5694db9..50c3041e2 100644 --- a/mediapipe/framework/api2/stream/loopback_test.cc +++ b/mediapipe/framework/api2/stream/loopback_test.cc @@ -6,6 +6,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" namespace mediapipe::api2::builder { namespace { @@ -33,22 +34,22 @@ TEST(LoopbackTest, GetLoopbackData) { // PreviousLoopbackCalculator configuration is incorrect here and should be // updated when corresponding b/175887687 is fixed. // Use mediapipe::aimatter::GraphBuilder to fix back edges in the graph. - EXPECT_THAT(graph.GetConfig(), - testing::EqualsProto( - mediapipe::ParseTextProtoOrDie(R"pb( - node { - calculator: "PreviousLoopbackCalculator" - input_stream: "LOOP:__stream_2" - input_stream: "MAIN:__stream_0" - output_stream: "PREV_LOOP:__stream_1" - } - node { - calculator: "TestDataProducer" - input_stream: "LOOPBACK_DATA:__stream_1" - output_stream: "PRODUCED_DATA:__stream_2" - } - input_stream: "TICK:__stream_0" - )pb"))); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "PreviousLoopbackCalculator" + input_stream: "LOOP:__stream_2" + input_stream: "MAIN:__stream_0" + output_stream: "PREV_LOOP:__stream_1" + } + node { + calculator: "TestDataProducer" + input_stream: "LOOPBACK_DATA:__stream_1" + output_stream: "PRODUCED_DATA:__stream_2" + } + input_stream: "TICK:__stream_0" + )pb"))); } } // namespace