From bbf40cba87cdd773bd519c7ec672562270955439 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 19 Sep 2023 13:16:11 -0700 Subject: [PATCH] split stream utility function. PiperOrigin-RevId: 566722901 --- mediapipe/calculators/core/BUILD | 1 + .../core/split_proto_list_calculator.cc | 13 + mediapipe/framework/api2/stream/BUILD | 37 ++ mediapipe/framework/api2/stream/split.h | 335 +++++++++++++ mediapipe/framework/api2/stream/split_test.cc | 473 ++++++++++++++++++ mediapipe/framework/formats/BUILD | 16 + 6 files changed, 875 insertions(+) create mode 100644 mediapipe/framework/api2/stream/split.h create mode 100644 mediapipe/framework/api2/stream/split_test.cc diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 02efc84ea..219803b03 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -944,6 +944,7 @@ cc_library( deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:body_rig_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/core/split_proto_list_calculator.cc b/mediapipe/calculators/core/split_proto_list_calculator.cc index fc1fd4df1..5f8bcf169 100644 --- a/mediapipe/calculators/core/split_proto_list_calculator.cc +++ b/mediapipe/calculators/core/split_proto_list_calculator.cc @@ -17,6 +17,7 @@ #include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/body_rig.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" @@ -196,6 +197,18 @@ class SplitLandmarkListCalculator }; REGISTER_CALCULATOR(SplitLandmarkListCalculator); +class SplitJointListCalculator : public SplitListsCalculator { + protected: + int ListSize(const JointList& list) const override { + return list.joint_size(); + } + const Joint GetItem(const JointList& list, int idx) const override { + return list.joint(idx); + } + Joint* AddItem(JointList& list) const override { return list.add_joint(); } +}; +REGISTER_CALCULATOR(SplitJointListCalculator); + } // namespace mediapipe // NOLINTNEXTLINE diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index 8c6e9bd18..1391165fa 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -192,3 +192,40 @@ cc_test( "//mediapipe/framework/port:status_matchers", ], ) + +cc_library( + name = "split", + hdrs = ["split.h"], + deps = [ + "//mediapipe/calculators/core:split_proto_list_calculator", + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) + +cc_test( + name = "split_test", + srcs = ["split_test.cc"], + deps = [ + ":split", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) diff --git a/mediapipe/framework/api2/stream/split.h b/mediapipe/framework/api2/stream/split.h new file mode 100644 index 000000000..6e723336a --- /dev/null +++ b/mediapipe/framework/api2/stream/split.h @@ -0,0 +1,335 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_ + +#include +#include +#include +#include +#include + +#include "mediapipe/calculators/core/split_vector_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "tensorflow/lite/c/common.h" + +namespace mediapipe::api2::builder { + +namespace stream_split_internal { + +// Helper function that adds a node to a graph, that is capable of splitting a +// specific type (T). +template +mediapipe::api2::builder::GenericNode& AddSplitVectorNode( + mediapipe::api2::builder::Graph& graph) { + if constexpr (std::is_same_v>) { + return graph.AddNode("SplitTfLiteTensorVectorCalculator"); + } else if constexpr (std::is_same_v>) { + return graph.AddNode("SplitTensorVectorCalculator"); + } else if constexpr (std::is_same_v>) { + return graph.AddNode("SplitUint64tVectorCalculator"); + } else if constexpr (std::is_same_v< + T, std::vector>) { + return graph.AddNode("SplitLandmarkVectorCalculator"); + } else if constexpr (std::is_same_v< + T, std::vector>) { + return graph.AddNode("SplitNormalizedLandmarkListVectorCalculator"); + } else if constexpr (std::is_same_v>) { + return graph.AddNode("SplitNormalizedRectVectorCalculator"); + } else if constexpr (std::is_same_v>) { + return graph.AddNode("SplitMatrixVectorCalculator"); + } else if constexpr (std::is_same_v>) { + return graph.AddNode("SplitDetectionVectorCalculator"); + } else if constexpr (std::is_same_v< + T, std::vector>) { + return graph.AddNode("SplitClassificationListVectorCalculator"); + } else if constexpr (std::is_same_v) { + return graph.AddNode("SplitNormalizedLandmarkListCalculator"); + } else if constexpr (std::is_same_v) { + return graph.AddNode("SplitLandmarkListCalculator"); + } else if constexpr (std::is_same_v) { + return graph.AddNode("SplitJointListCalculator"); + } else { + static_assert(dependent_false::value, + "Split node is not available for the specified type."); + } +} + +template +struct split_result_item { + using type = typename T::value_type; +}; +template <> +struct split_result_item { + using type = mediapipe::NormalizedLandmark; +}; +template <> +struct split_result_item { + using type = mediapipe::Landmark; +}; + +template +struct split_result_item { + using type = std::vector; +}; +template <> +struct split_result_item { + using type = mediapipe::NormalizedLandmarkList; +}; +template <> +struct split_result_item { + using type = mediapipe::LandmarkList; +}; + +template +auto Split(Stream items, I begin, I end, + mediapipe::api2::builder::Graph& graph) { + auto& splitter = AddSplitVectorNode(graph); + items.ConnectTo(splitter.In("")); + + constexpr bool kIteratorContainsRanges = + std::is_same_v::value_type, + std::pair>; + using R = + typename split_result_item::type; + auto& splitter_opts = + splitter.template GetOptions(); + if constexpr (!kIteratorContainsRanges) { + splitter_opts.set_element_only(true); + } + std::vector> result; + int output = 0; + for (auto it = begin; it != end; ++it) { + auto* range = splitter_opts.add_ranges(); + if constexpr (kIteratorContainsRanges) { + range->set_begin(it->first); + range->set_end(it->second); + } else { + range->set_begin(*it); + range->set_end(*it + 1); + } + result.push_back(splitter.Out("")[output++].template Cast()); + } + return result; +} + +template +Stream SplitAndCombine(Stream items, I begin, I end, + mediapipe::api2::builder::Graph& graph) { + auto& splitter = AddSplitVectorNode(graph); + items.ConnectTo(splitter.In("")); + + constexpr bool kIteratorContainsRanges = + std::is_same_v::value_type, + std::pair>; + + auto& splitter_opts = + splitter.template GetOptions(); + splitter_opts.set_combine_outputs(true); + + for (auto it = begin; it != end; ++it) { + auto* range = splitter_opts.add_ranges(); + if constexpr (kIteratorContainsRanges) { + range->set_begin(it->first); + range->set_end(it->second); + } else { + range->set_begin(*it); + range->set_end(*it + 1); + } + } + return splitter.Out("").template Cast(); +} + +} // namespace stream_split_internal + +// Splits stream containing a collection based on passed @indices into a vector +// of streams where each steam repesents individual item of a collection. +// +// Example: +// ``` +// +// Graph graph; +// std::vector indices = {0, 1, 2, 3}; +// +// Stream> detections = ...; +// std::vector> detections_split = +// Split(detections, indices, graph); +// +// Stream landmarks = ...; +// std::vector> landmarks_split = +// Split(landmarks, indices, graph); +// +// ``` +template +auto Split(Stream items, const I& indices, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::Split(items, indices.begin(), indices.end(), + graph); +} +// Splits stream containing a collection based on passed @indices into a vector +// of streams where each steam repesents individual item of a collection. +// +// Example: +// ``` +// +// Graph graph; +// std::vector indices = {0, 1, 2, 3}; +// +// Stream> detections = ...; +// std::vector> detections_split = +// Split(detections, indices, graph); +// +// Stream landmarks = ...; +// std::vector> landmarks_split = +// Split(landmarks, indices, graph); +// +// ``` +template +auto Split(Stream items, std::initializer_list indices, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::Split(items, indices.begin(), indices.end(), + graph); +} + +// Splits stream containing a collection into a sub ranges, each represented as +// a stream containing same collection type. +// +// Example: +// ``` +// +// Graph graph; +// std::vector> ranges = {{0, 3}, {7, 10}}; +// +// Stream> detections = ...; +// std::vector>> detections_split = +// SplitToRanges(detections, ranges, graph); +// +// Stream landmarks = ...; +// std::vector> landmarks_split = +// SplitToRanges(landmarks, ranges, graph); +// +// ``` +template +auto SplitToRanges(Stream items, const RangeT& ranges, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::Split(items, ranges.begin(), ranges.end(), + graph); +} + +// Splits stream containing a collection into a sub ranges, each represented as +// a stream containing same collection type. +// +// Example: +// ``` +// +// Graph graph; +// std::vector> ranges = {{0, 3}, {7, 10}}; +// +// Stream> detections = ...; +// std::vector>> detections_split = +// SplitToRanges(detections, ranges, graph); +// +// Stream landmarks = ...; +// std::vector> landmarks_split = +// SplitToRanges(landmarks, ranges, graph); +// +// ``` +template +auto SplitToRanges(Stream items, + std::initializer_list> ranges, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::Split(items, ranges.begin(), ranges.end(), + graph); +} + +// Splits stream containing a collection into a sub ranges and combines them +// into a stream containing same collection type. +// +// Example: +// ``` +// +// Graph graph; +// std::vector> ranges = {{0, 3}, {7, 10}}; +// +// Stream> detections = ...; +// Stream> detections_split_and_combined = +// SplitAndCombine(detections, ranges, graph); +// +// Stream landmarks = ...; +// Stream landmarks_split_and_combined = +// SplitAndCombine(landmarks, ranges, graph); +// +// ``` +template +Stream SplitAndCombine(Stream items, + const RangeT& ranges, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::SplitAndCombine(items, ranges.begin(), + ranges.end(), graph); +} + +// Splits stream containing a collection into a sub ranges and combines them +// into a stream containing same collection type. +// +// Example: +// ``` +// +// Graph graph; +// +// Stream> detections = ...; +// Stream> detections_split_and_combined = +// SplitAndCombine(detections, {{0, 3}, {7, 10}}, graph); +// +// Stream landmarks = ...; +// Stream landmarks_split_and_combined = +// SplitAndCombine(landmarks, {{0, 3}, {7, 10}}, graph); +// +// ``` +template +Stream SplitAndCombine( + Stream items, + std::initializer_list> ranges, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::SplitAndCombine(items, ranges.begin(), + ranges.end(), graph); +} + +// Splits stream containing a collection into individual items and combines them +// into a stream containing same collection type. +// +// Example: +// ``` +// +// Graph graph; +// +// Stream> detections = ...; +// Stream> detections_split_and_combined = +// SplitAndCombine(detections, {0, 7, 10}, graph); +// +// Stream landmarks = ...; +// Stream landmarks_split_and_combined = +// SplitAndCombine(landmarks, {0, 7, 10}, graph); +// +// ``` +template +Stream SplitAndCombine(Stream items, + std::initializer_list ranges, + mediapipe::api2::builder::Graph& graph) { + return stream_split_internal::SplitAndCombine(items, ranges.begin(), + ranges.end(), graph); +} + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_ diff --git a/mediapipe/framework/api2/stream/split_test.cc b/mediapipe/framework/api2/stream/split_test.cc new file mode 100644 index 000000000..170909c72 --- /dev/null +++ b/mediapipe/framework/api2/stream/split_test.cc @@ -0,0 +1,473 @@ +#include "mediapipe/framework/api2/stream/split.h" + +#include +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(SplitTest, SplitToRanges2Ranges) { + Graph graph; + Stream> tensors = + graph.In("TENSORS").Cast>(); + std::vector>> result = + SplitToRanges(tensors, {{0, 1}, {1, 2}}, graph); + EXPECT_EQ(result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitTensorVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 1 end: 2 } + } + } + } + input_stream: "TENSORS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, Split2Items) { + Graph graph; + Stream> tensors = + graph.In("TENSORS").Cast>(); + std::vector> result = Split(tensors, {0, 1}, graph); + EXPECT_EQ(result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitTensorVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 1 end: 2 } + element_only: true + } + } + } + input_stream: "TENSORS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, Split2Uint64tItems) { + Graph graph; + Stream> ids = + graph.In("IDS").Cast>(); + std::vector> result = Split(ids, {0, 1}, graph); + EXPECT_EQ(result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitUint64tVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 1 end: 2 } + element_only: true + } + } + } + input_stream: "IDS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitToRanges5Ranges) { + Graph graph; + Stream> tensors = + graph.In("RECTS").Cast>(); + std::vector>> result = + SplitToRanges(tensors, {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}}, graph); + EXPECT_EQ(result.size(), 5); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitNormalizedRectVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + output_stream: "__stream_3" + output_stream: "__stream_4" + output_stream: "__stream_5" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 1 end: 2 } + ranges { begin: 2 end: 3 } + ranges { begin: 3 end: 4 } + ranges { begin: 4 end: 5 } + } + } + } + input_stream: "RECTS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, Split5Items) { + Graph graph; + Stream> tensors = + graph.In("RECTS").Cast>(); + std::vector> result = + Split(tensors, {0, 1, 2, 3, 4}, graph); + EXPECT_EQ(result.size(), 5); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitNormalizedRectVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + output_stream: "__stream_3" + output_stream: "__stream_4" + output_stream: "__stream_5" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 1 end: 2 } + ranges { begin: 2 end: 3 } + ranges { begin: 3 end: 4 } + ranges { begin: 4 end: 5 } + element_only: true + } + } + } + input_stream: "RECTS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitPassingVectorIndices) { + Graph graph; + Stream> tensors = + graph.In("RECTS").Cast>(); + std::vector indices = {250, 300}; + std::vector> second_split_result = + Split(tensors, indices, graph); + EXPECT_EQ(second_split_result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitNormalizedRectVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 251 } + ranges { begin: 300 end: 301 } + element_only: true + } + } + } + input_stream: "RECTS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitToRangesPassingVectorRanges) { + Graph graph; + Stream> tensors = + graph.In("RECTS").Cast>(); + std::vector> indices = {{250, 255}, {300, 301}}; + std::vector>> second_split_result = + SplitToRanges(tensors, indices, graph); + EXPECT_EQ(second_split_result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitNormalizedRectVectorCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 255 } + ranges { begin: 300 end: 301 } + } + } + } + input_stream: "RECTS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitToRangesNormalizedLandmarkList) { + Graph graph; + Stream tensors = + graph.In("LM_LIST").Cast(); + std::vector> indices = {{250, 255}, {300, 301}}; + std::vector> second_split_result = + SplitToRanges(tensors, indices, graph); + EXPECT_EQ(second_split_result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitNormalizedLandmarkListCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 255 } + ranges { begin: 300 end: 301 } + } + } + } + input_stream: "LM_LIST:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitNormalizedLandmarkList) { + Graph graph; + Stream tensors = + graph.In("LM_LIST").Cast(); + std::vector> second_split_result = + Split(tensors, {250, 300}, graph); + EXPECT_EQ(second_split_result.size(), 2); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitNormalizedLandmarkListCalculator" + input_stream: "__stream_0" + output_stream: "__stream_1" + output_stream: "__stream_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 251 } + ranges { begin: 300 end: 301 } + element_only: true + } + } + } + input_stream: "LM_LIST:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitAndCombineRanges) { + Graph graph; + Stream> tensors = + graph.In("TENSORS").Cast>(); + Stream> result = + SplitAndCombine(tensors, {{0, 1}, {2, 5}, {70, 75}}, graph); + result.SetName("tensors_split_and_combined"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitTensorVectorCalculator" + input_stream: "__stream_0" + output_stream: "tensors_split_and_combined" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 2 end: 5 } + ranges { begin: 70 end: 75 } + combine_outputs: true + } + } + } + input_stream: "TENSORS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitAndCombineIndividualIndices) { + Graph graph; + Stream> tensors = + graph.In("TENSORS").Cast>(); + Stream> result = + SplitAndCombine(tensors, {0, 2, 70, 100}, graph); + result.SetName("tensors_split_and_combined"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitTensorVectorCalculator" + input_stream: "__stream_0" + output_stream: "tensors_split_and_combined" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 0 end: 1 } + ranges { begin: 2 end: 3 } + ranges { begin: 70 end: 71 } + ranges { begin: 100 end: 101 } + combine_outputs: true + } + } + } + input_stream: "TENSORS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitAndCombineLandmarkList) { + Graph graph; + Stream tensors = graph.In("LM_LIST").Cast(); + std::vector> ranges = {{250, 255}, {300, 301}}; + Stream landmark_list = SplitAndCombine(tensors, ranges, graph); + landmark_list.SetName("landmark_list"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitLandmarkListCalculator" + input_stream: "__stream_0" + output_stream: "landmark_list" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 255 } + ranges { begin: 300 end: 301 } + combine_outputs: true + } + } + } + input_stream: "LM_LIST:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitAndCombineLandmarkListIndividualIndices) { + Graph graph; + Stream tensors = graph.In("LM_LIST").Cast(); + std::vector indices = {250, 300}; + Stream landmark_list = SplitAndCombine(tensors, indices, graph); + landmark_list.SetName("landmark_list"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitLandmarkListCalculator" + input_stream: "__stream_0" + output_stream: "landmark_list" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 251 } + ranges { begin: 300 end: 301 } + combine_outputs: true + } + } + } + input_stream: "LM_LIST:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitAndCombineJointList) { + Graph graph; + Stream tensors = graph.In("JT_LIST").Cast(); + std::vector> ranges = {{250, 255}, {300, 301}}; + Stream joint_list = SplitAndCombine(tensors, ranges, graph); + joint_list.SetName("joint_list"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitJointListCalculator" + input_stream: "__stream_0" + output_stream: "joint_list" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 255 } + ranges { begin: 300 end: 301 } + combine_outputs: true + } + } + } + input_stream: "JT_LIST:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(SplitTest, SplitAndCombineJointListIndividualIndices) { + Graph graph; + Stream tensors = graph.In("LM_LIST").Cast(); + std::vector indices = {250, 300}; + Stream joint_list = SplitAndCombine(tensors, indices, graph); + joint_list.SetName("joint_list"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SplitJointListCalculator" + input_stream: "__stream_0" + output_stream: "joint_list" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges { begin: 250 end: 251 } + ranges { begin: 300 end: 301 } + combine_outputs: true + } + } + } + input_stream: "LM_LIST:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 9a570d524..7f1792a73 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -533,3 +533,19 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +mediapipe_proto_library( + name = "body_rig_proto", + srcs = ["body_rig.proto"], +) + +mediapipe_register_type( + base_name = "body_rig", + include_headers = ["mediapipe/framework/formats/body_rig.pb.h"], + types = [ + "::mediapipe::Joint", + "::mediapipe::JointList", + "::std::vector<::mediapipe::JointList>", + ], + deps = [":body_rig_cc_proto"], +)