Move stream API image_size to third_party.

PiperOrigin-RevId: 559475476
This commit is contained in:
MediaPipe Team 2023-08-23 10:42:26 -07:00 committed by Copybara-Service
parent 8689f4f595
commit f645c59746
4 changed files with 153 additions and 17 deletions

View File

@ -10,5 +10,49 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
], ],
alwayslink = 1, )
cc_test(
name = "loopback_test",
srcs = ["loopback_test.cc"],
deps = [
":loopback",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)
cc_library(
name = "image_size",
hdrs = ["image_size.h"],
deps = [
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/gpu:gpu_buffer",
],
)
cc_test(
name = "image_size_test",
srcs = ["image_size_test.cc"],
deps = [
":image_size",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
"//mediapipe/gpu:gpu_buffer",
],
) )

View File

@ -0,0 +1,34 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_
#include <utility>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gpu_buffer.h"
namespace mediapipe::api2::builder {
// Updates graph to calculate image size and returns corresponding stream.
//
// @image image represented as ImageFrame/Image/GpuBuffer.
// @graph graph to update.
template <typename ImageT>
Stream<std::pair<int, int>> GetImageSize(
Stream<ImageT> image, mediapipe::api2::builder::Graph& graph) {
auto& img_props_node = graph.AddNode("ImagePropertiesCalculator");
if constexpr (std::is_same_v<ImageT, ImageFrame> ||
std::is_same_v<ImageT, mediapipe::Image>) {
image.ConnectTo(img_props_node.In("IMAGE"));
} else if constexpr (std::is_same_v<ImageT, GpuBuffer>) {
image.ConnectTo(img_props_node.In("IMAGE_GPU"));
} else {
static_assert(dependent_false<ImageT>::value, "Type not supported.");
}
return img_props_node.Out("SIZE").Cast<std::pair<int, int>>();
}
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_

View File

@ -0,0 +1,57 @@
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/gpu/gpu_buffer.h"
namespace mediapipe::api2::builder {
namespace {
TEST(GetImageSize, VerifyConfig) {
Graph graph;
Stream<ImageFrame> image_frame = graph.In("IMAGE_FRAME").Cast<ImageFrame>();
image_frame.SetName("image_frame");
Stream<GpuBuffer> gpu_buffer = graph.In("GPU_BUFFER").Cast<GpuBuffer>();
gpu_buffer.SetName("gpu_buffer");
Stream<mediapipe::Image> image = graph.In("IMAGE").Cast<mediapipe::Image>();
image.SetName("image");
GetImageSize(image_frame, graph).SetName("image_frame_size");
GetImageSize(gpu_buffer, graph).SetName("gpu_buffer_size");
GetImageSize(image, graph).SetName("image_size");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:image_frame"
output_stream: "SIZE:image_frame_size"
}
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE_GPU:gpu_buffer"
output_stream: "SIZE:gpu_buffer_size"
}
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:image"
output_stream: "SIZE:image_size"
}
input_stream: "GPU_BUFFER:gpu_buffer"
input_stream: "IMAGE:image"
input_stream: "IMAGE_FRAME:image_frame"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace mediapipe::api2::builder

View File

@ -6,6 +6,7 @@
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe::api2::builder { namespace mediapipe::api2::builder {
namespace { namespace {
@ -33,22 +34,22 @@ TEST(LoopbackTest, GetLoopbackData) {
// PreviousLoopbackCalculator configuration is incorrect here and should be // PreviousLoopbackCalculator configuration is incorrect here and should be
// updated when corresponding b/175887687 is fixed. // updated when corresponding b/175887687 is fixed.
// Use mediapipe::aimatter::GraphBuilder to fix back edges in the graph. // Use mediapipe::aimatter::GraphBuilder to fix back edges in the graph.
EXPECT_THAT(graph.GetConfig(), EXPECT_THAT(
testing::EqualsProto( graph.GetConfig(),
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node { node {
calculator: "PreviousLoopbackCalculator" calculator: "PreviousLoopbackCalculator"
input_stream: "LOOP:__stream_2" input_stream: "LOOP:__stream_2"
input_stream: "MAIN:__stream_0" input_stream: "MAIN:__stream_0"
output_stream: "PREV_LOOP:__stream_1" output_stream: "PREV_LOOP:__stream_1"
} }
node { node {
calculator: "TestDataProducer" calculator: "TestDataProducer"
input_stream: "LOOPBACK_DATA:__stream_1" input_stream: "LOOPBACK_DATA:__stream_1"
output_stream: "PRODUCED_DATA:__stream_2" output_stream: "PRODUCED_DATA:__stream_2"
} }
input_stream: "TICK:__stream_0" input_stream: "TICK:__stream_0"
)pb"))); )pb")));
} }
} // namespace } // namespace