split stream utility function.

PiperOrigin-RevId: 566722901
This commit is contained in:
MediaPipe Team 2023-09-19 13:16:11 -07:00 committed by Copybara-Service
parent 58bb2d1b92
commit bbf40cba87
6 changed files with 875 additions and 0 deletions

View File

@ -944,6 +944,7 @@ cc_library(
deps = [
":split_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",

View File

@ -17,6 +17,7 @@
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
@ -196,6 +197,18 @@ class SplitLandmarkListCalculator
};
REGISTER_CALCULATOR(SplitLandmarkListCalculator);
class SplitJointListCalculator : public SplitListsCalculator<Joint, JointList> {
protected:
int ListSize(const JointList& list) const override {
return list.joint_size();
}
const Joint GetItem(const JointList& list, int idx) const override {
return list.joint(idx);
}
Joint* AddItem(JointList& list) const override { return list.add_joint(); }
};
REGISTER_CALCULATOR(SplitJointListCalculator);
} // namespace mediapipe
// NOLINTNEXTLINE

View File

@ -192,3 +192,40 @@ cc_test(
"//mediapipe/framework/port:status_matchers",
],
)
cc_library(
name = "split",
hdrs = ["split.h"],
deps = [
"//mediapipe/calculators/core:split_proto_list_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"@org_tensorflow//tensorflow/lite/c:common",
],
)
cc_test(
name = "split_test",
srcs = ["split_test.cc"],
deps = [
":split",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)

View File

@ -0,0 +1,335 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "tensorflow/lite/c/common.h"
namespace mediapipe::api2::builder {
namespace stream_split_internal {
// Helper function that adds a node to a graph, that is capable of splitting a
// specific type (T).
template <class T>
mediapipe::api2::builder::GenericNode& AddSplitVectorNode(
mediapipe::api2::builder::Graph& graph) {
if constexpr (std::is_same_v<T, std::vector<TfLiteTensor>>) {
return graph.AddNode("SplitTfLiteTensorVectorCalculator");
} else if constexpr (std::is_same_v<T, std::vector<mediapipe::Tensor>>) {
return graph.AddNode("SplitTensorVectorCalculator");
} else if constexpr (std::is_same_v<T, std::vector<uint64_t>>) {
return graph.AddNode("SplitUint64tVectorCalculator");
} else if constexpr (std::is_same_v<
T, std::vector<mediapipe::NormalizedLandmark>>) {
return graph.AddNode("SplitLandmarkVectorCalculator");
} else if constexpr (std::is_same_v<
T, std::vector<mediapipe::NormalizedLandmarkList>>) {
return graph.AddNode("SplitNormalizedLandmarkListVectorCalculator");
} else if constexpr (std::is_same_v<T,
std::vector<mediapipe::NormalizedRect>>) {
return graph.AddNode("SplitNormalizedRectVectorCalculator");
} else if constexpr (std::is_same_v<T, std::vector<Matrix>>) {
return graph.AddNode("SplitMatrixVectorCalculator");
} else if constexpr (std::is_same_v<T, std::vector<mediapipe::Detection>>) {
return graph.AddNode("SplitDetectionVectorCalculator");
} else if constexpr (std::is_same_v<
T, std::vector<mediapipe::ClassificationList>>) {
return graph.AddNode("SplitClassificationListVectorCalculator");
} else if constexpr (std::is_same_v<T, mediapipe::NormalizedLandmarkList>) {
return graph.AddNode("SplitNormalizedLandmarkListCalculator");
} else if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
return graph.AddNode("SplitLandmarkListCalculator");
} else if constexpr (std::is_same_v<T, mediapipe::JointList>) {
return graph.AddNode("SplitJointListCalculator");
} else {
static_assert(dependent_false<T>::value,
"Split node is not available for the specified type.");
}
}
template <typename T, bool kIteratorContainsRanges = false>
struct split_result_item {
using type = typename T::value_type;
};
template <>
struct split_result_item<mediapipe::NormalizedLandmarkList,
/*kIteratorContainsRanges=*/false> {
using type = mediapipe::NormalizedLandmark;
};
template <>
struct split_result_item<mediapipe::LandmarkList,
/*kIteratorContainsRanges=*/false> {
using type = mediapipe::Landmark;
};
template <typename T>
struct split_result_item<T, /*kIteratorContainsRanges=*/true> {
using type = std::vector<typename T::value_type>;
};
template <>
struct split_result_item<mediapipe::NormalizedLandmarkList,
/*kIteratorContainsRanges=*/true> {
using type = mediapipe::NormalizedLandmarkList;
};
template <>
struct split_result_item<mediapipe::LandmarkList,
/*kIteratorContainsRanges=*/true> {
using type = mediapipe::LandmarkList;
};
template <typename CollectionT, typename I>
auto Split(Stream<CollectionT> items, I begin, I end,
mediapipe::api2::builder::Graph& graph) {
auto& splitter = AddSplitVectorNode<CollectionT>(graph);
items.ConnectTo(splitter.In(""));
constexpr bool kIteratorContainsRanges =
std::is_same_v<typename std::iterator_traits<I>::value_type,
std::pair<int, int>>;
using R =
typename split_result_item<CollectionT, kIteratorContainsRanges>::type;
auto& splitter_opts =
splitter.template GetOptions<mediapipe::SplitVectorCalculatorOptions>();
if constexpr (!kIteratorContainsRanges) {
splitter_opts.set_element_only(true);
}
std::vector<Stream<R>> result;
int output = 0;
for (auto it = begin; it != end; ++it) {
auto* range = splitter_opts.add_ranges();
if constexpr (kIteratorContainsRanges) {
range->set_begin(it->first);
range->set_end(it->second);
} else {
range->set_begin(*it);
range->set_end(*it + 1);
}
result.push_back(splitter.Out("")[output++].template Cast<R>());
}
return result;
}
template <typename CollectionT, typename I>
Stream<CollectionT> SplitAndCombine(Stream<CollectionT> items, I begin, I end,
mediapipe::api2::builder::Graph& graph) {
auto& splitter = AddSplitVectorNode<CollectionT>(graph);
items.ConnectTo(splitter.In(""));
constexpr bool kIteratorContainsRanges =
std::is_same_v<typename std::iterator_traits<I>::value_type,
std::pair<int, int>>;
auto& splitter_opts =
splitter.template GetOptions<mediapipe::SplitVectorCalculatorOptions>();
splitter_opts.set_combine_outputs(true);
for (auto it = begin; it != end; ++it) {
auto* range = splitter_opts.add_ranges();
if constexpr (kIteratorContainsRanges) {
range->set_begin(it->first);
range->set_end(it->second);
} else {
range->set_begin(*it);
range->set_end(*it + 1);
}
}
return splitter.Out("").template Cast<CollectionT>();
}
} // namespace stream_split_internal
// Splits stream containing a collection based on passed @indices into a vector
// of streams where each steam repesents individual item of a collection.
//
// Example:
// ```
//
// Graph graph;
// std::vector<int> indices = {0, 1, 2, 3};
//
// Stream<std::vector<Detection>> detections = ...;
// std::vector<Stream<Detection>> detections_split =
// Split(detections, indices, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// std::vector<Stream<NormalizedLandmark>> landmarks_split =
// Split(landmarks, indices, graph);
//
// ```
template <typename CollectionT, typename I>
auto Split(Stream<CollectionT> items, const I& indices,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::Split(items, indices.begin(), indices.end(),
graph);
}
// Splits stream containing a collection based on passed @indices into a vector
// of streams where each steam repesents individual item of a collection.
//
// Example:
// ```
//
// Graph graph;
// std::vector<int> indices = {0, 1, 2, 3};
//
// Stream<std::vector<Detection>> detections = ...;
// std::vector<Stream<Detection>> detections_split =
// Split(detections, indices, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// std::vector<Stream<NormalizedLandmark>> landmarks_split =
// Split(landmarks, indices, graph);
//
// ```
template <typename CollectionT>
auto Split(Stream<CollectionT> items, std::initializer_list<int> indices,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::Split(items, indices.begin(), indices.end(),
graph);
}
// Splits stream containing a collection into a sub ranges, each represented as
// a stream containing same collection type.
//
// Example:
// ```
//
// Graph graph;
// std::vector<std::pair<int, int>> ranges = {{0, 3}, {7, 10}};
//
// Stream<std::vector<Detection>> detections = ...;
// std::vector<Stream<std::vector<Detection>>> detections_split =
// SplitToRanges(detections, ranges, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// std::vector<Stream<NormalizedLandmarkList>> landmarks_split =
// SplitToRanges(landmarks, ranges, graph);
//
// ```
template <typename CollectionT, typename RangeT>
auto SplitToRanges(Stream<CollectionT> items, const RangeT& ranges,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::Split(items, ranges.begin(), ranges.end(),
graph);
}
// Splits stream containing a collection into a sub ranges, each represented as
// a stream containing same collection type.
//
// Example:
// ```
//
// Graph graph;
// std::vector<std::pair<int, int>> ranges = {{0, 3}, {7, 10}};
//
// Stream<std::vector<Detection>> detections = ...;
// std::vector<Stream<std::vector<Detection>>> detections_split =
// SplitToRanges(detections, ranges, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// std::vector<Stream<NormalizedLandmarkList>> landmarks_split =
// SplitToRanges(landmarks, ranges, graph);
//
// ```
template <typename CollectionT>
auto SplitToRanges(Stream<CollectionT> items,
std::initializer_list<std::pair<int, int>> ranges,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::Split(items, ranges.begin(), ranges.end(),
graph);
}
// Splits stream containing a collection into a sub ranges and combines them
// into a stream containing same collection type.
//
// Example:
// ```
//
// Graph graph;
// std::vector<std::pair<int, int>> ranges = {{0, 3}, {7, 10}};
//
// Stream<std::vector<Detection>> detections = ...;
// Stream<std::vector<Detection>> detections_split_and_combined =
// SplitAndCombine(detections, ranges, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// Stream<NormalizedLandmarkList> landmarks_split_and_combined =
// SplitAndCombine(landmarks, ranges, graph);
//
// ```
template <typename CollectionT, typename RangeT>
Stream<CollectionT> SplitAndCombine(Stream<CollectionT> items,
const RangeT& ranges,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::SplitAndCombine(items, ranges.begin(),
ranges.end(), graph);
}
// Splits stream containing a collection into a sub ranges and combines them
// into a stream containing same collection type.
//
// Example:
// ```
//
// Graph graph;
//
// Stream<std::vector<Detection>> detections = ...;
// Stream<std::vector<Detection>> detections_split_and_combined =
// SplitAndCombine(detections, {{0, 3}, {7, 10}}, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// Stream<NormalizedLandmarkList> landmarks_split_and_combined =
// SplitAndCombine(landmarks, {{0, 3}, {7, 10}}, graph);
//
// ```
template <typename CollectionT>
Stream<CollectionT> SplitAndCombine(
Stream<CollectionT> items,
std::initializer_list<std::pair<int, int>> ranges,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::SplitAndCombine(items, ranges.begin(),
ranges.end(), graph);
}
// Splits stream containing a collection into individual items and combines them
// into a stream containing same collection type.
//
// Example:
// ```
//
// Graph graph;
//
// Stream<std::vector<Detection>> detections = ...;
// Stream<std::vector<Detection>> detections_split_and_combined =
// SplitAndCombine(detections, {0, 7, 10}, graph);
//
// Stream<NormalizedLandmarkList> landmarks = ...;
// Stream<NormalizedLandmarkList> landmarks_split_and_combined =
// SplitAndCombine(landmarks, {0, 7, 10}, graph);
//
// ```
template <typename CollectionT>
Stream<CollectionT> SplitAndCombine(Stream<CollectionT> items,
std::initializer_list<int> ranges,
mediapipe::api2::builder::Graph& graph) {
return stream_split_internal::SplitAndCombine(items, ranges.begin(),
ranges.end(), graph);
}
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_

View File

@ -0,0 +1,473 @@
#include "mediapipe/framework/api2/stream/split.h"
#include <cstdint>
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe::api2::builder {
namespace {
TEST(SplitTest, SplitToRanges2Ranges) {
Graph graph;
Stream<std::vector<mediapipe::Tensor>> tensors =
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
std::vector<Stream<std::vector<mediapipe::Tensor>>> result =
SplitToRanges(tensors, {{0, 1}, {1, 2}}, graph);
EXPECT_EQ(result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitTensorVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 1 end: 2 }
}
}
}
input_stream: "TENSORS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, Split2Items) {
Graph graph;
Stream<std::vector<mediapipe::Tensor>> tensors =
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
std::vector<Stream<mediapipe::Tensor>> result = Split(tensors, {0, 1}, graph);
EXPECT_EQ(result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitTensorVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 1 end: 2 }
element_only: true
}
}
}
input_stream: "TENSORS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, Split2Uint64tItems) {
Graph graph;
Stream<std::vector<uint64_t>> ids =
graph.In("IDS").Cast<std::vector<uint64_t>>();
std::vector<Stream<uint64_t>> result = Split(ids, {0, 1}, graph);
EXPECT_EQ(result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitUint64tVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 1 end: 2 }
element_only: true
}
}
}
input_stream: "IDS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitToRanges5Ranges) {
Graph graph;
Stream<std::vector<NormalizedRect>> tensors =
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
std::vector<Stream<std::vector<NormalizedRect>>> result =
SplitToRanges(tensors, {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}}, graph);
EXPECT_EQ(result.size(), 5);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitNormalizedRectVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
output_stream: "__stream_3"
output_stream: "__stream_4"
output_stream: "__stream_5"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 1 end: 2 }
ranges { begin: 2 end: 3 }
ranges { begin: 3 end: 4 }
ranges { begin: 4 end: 5 }
}
}
}
input_stream: "RECTS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, Split5Items) {
Graph graph;
Stream<std::vector<NormalizedRect>> tensors =
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
std::vector<Stream<NormalizedRect>> result =
Split(tensors, {0, 1, 2, 3, 4}, graph);
EXPECT_EQ(result.size(), 5);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitNormalizedRectVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
output_stream: "__stream_3"
output_stream: "__stream_4"
output_stream: "__stream_5"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 1 end: 2 }
ranges { begin: 2 end: 3 }
ranges { begin: 3 end: 4 }
ranges { begin: 4 end: 5 }
element_only: true
}
}
}
input_stream: "RECTS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitPassingVectorIndices) {
Graph graph;
Stream<std::vector<NormalizedRect>> tensors =
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
std::vector<int> indices = {250, 300};
std::vector<Stream<NormalizedRect>> second_split_result =
Split(tensors, indices, graph);
EXPECT_EQ(second_split_result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitNormalizedRectVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 251 }
ranges { begin: 300 end: 301 }
element_only: true
}
}
}
input_stream: "RECTS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitToRangesPassingVectorRanges) {
Graph graph;
Stream<std::vector<NormalizedRect>> tensors =
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
std::vector<std::pair<int, int>> indices = {{250, 255}, {300, 301}};
std::vector<Stream<std::vector<NormalizedRect>>> second_split_result =
SplitToRanges(tensors, indices, graph);
EXPECT_EQ(second_split_result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitNormalizedRectVectorCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 255 }
ranges { begin: 300 end: 301 }
}
}
}
input_stream: "RECTS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitToRangesNormalizedLandmarkList) {
Graph graph;
Stream<NormalizedLandmarkList> tensors =
graph.In("LM_LIST").Cast<NormalizedLandmarkList>();
std::vector<std::pair<int, int>> indices = {{250, 255}, {300, 301}};
std::vector<Stream<NormalizedLandmarkList>> second_split_result =
SplitToRanges(tensors, indices, graph);
EXPECT_EQ(second_split_result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 255 }
ranges { begin: 300 end: 301 }
}
}
}
input_stream: "LM_LIST:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitNormalizedLandmarkList) {
Graph graph;
Stream<NormalizedLandmarkList> tensors =
graph.In("LM_LIST").Cast<NormalizedLandmarkList>();
std::vector<Stream<NormalizedLandmark>> second_split_result =
Split(tensors, {250, 300}, graph);
EXPECT_EQ(second_split_result.size(), 2);
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "__stream_0"
output_stream: "__stream_1"
output_stream: "__stream_2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 251 }
ranges { begin: 300 end: 301 }
element_only: true
}
}
}
input_stream: "LM_LIST:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitAndCombineRanges) {
Graph graph;
Stream<std::vector<mediapipe::Tensor>> tensors =
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
Stream<std::vector<mediapipe::Tensor>> result =
SplitAndCombine(tensors, {{0, 1}, {2, 5}, {70, 75}}, graph);
result.SetName("tensors_split_and_combined");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitTensorVectorCalculator"
input_stream: "__stream_0"
output_stream: "tensors_split_and_combined"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 2 end: 5 }
ranges { begin: 70 end: 75 }
combine_outputs: true
}
}
}
input_stream: "TENSORS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitAndCombineIndividualIndices) {
Graph graph;
Stream<std::vector<mediapipe::Tensor>> tensors =
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
Stream<std::vector<mediapipe::Tensor>> result =
SplitAndCombine(tensors, {0, 2, 70, 100}, graph);
result.SetName("tensors_split_and_combined");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitTensorVectorCalculator"
input_stream: "__stream_0"
output_stream: "tensors_split_and_combined"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 0 end: 1 }
ranges { begin: 2 end: 3 }
ranges { begin: 70 end: 71 }
ranges { begin: 100 end: 101 }
combine_outputs: true
}
}
}
input_stream: "TENSORS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitAndCombineLandmarkList) {
Graph graph;
Stream<LandmarkList> tensors = graph.In("LM_LIST").Cast<LandmarkList>();
std::vector<std::pair<int, int>> ranges = {{250, 255}, {300, 301}};
Stream<LandmarkList> landmark_list = SplitAndCombine(tensors, ranges, graph);
landmark_list.SetName("landmark_list");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitLandmarkListCalculator"
input_stream: "__stream_0"
output_stream: "landmark_list"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 255 }
ranges { begin: 300 end: 301 }
combine_outputs: true
}
}
}
input_stream: "LM_LIST:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitAndCombineLandmarkListIndividualIndices) {
Graph graph;
Stream<LandmarkList> tensors = graph.In("LM_LIST").Cast<LandmarkList>();
std::vector<int> indices = {250, 300};
Stream<LandmarkList> landmark_list = SplitAndCombine(tensors, indices, graph);
landmark_list.SetName("landmark_list");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitLandmarkListCalculator"
input_stream: "__stream_0"
output_stream: "landmark_list"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 251 }
ranges { begin: 300 end: 301 }
combine_outputs: true
}
}
}
input_stream: "LM_LIST:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitAndCombineJointList) {
Graph graph;
Stream<JointList> tensors = graph.In("JT_LIST").Cast<JointList>();
std::vector<std::pair<int, int>> ranges = {{250, 255}, {300, 301}};
Stream<JointList> joint_list = SplitAndCombine(tensors, ranges, graph);
joint_list.SetName("joint_list");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitJointListCalculator"
input_stream: "__stream_0"
output_stream: "joint_list"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 255 }
ranges { begin: 300 end: 301 }
combine_outputs: true
}
}
}
input_stream: "JT_LIST:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(SplitTest, SplitAndCombineJointListIndividualIndices) {
Graph graph;
Stream<JointList> tensors = graph.In("LM_LIST").Cast<JointList>();
std::vector<int> indices = {250, 300};
Stream<JointList> joint_list = SplitAndCombine(tensors, indices, graph);
joint_list.SetName("joint_list");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SplitJointListCalculator"
input_stream: "__stream_0"
output_stream: "joint_list"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges { begin: 250 end: 251 }
ranges { begin: 300 end: 301 }
combine_outputs: true
}
}
}
input_stream: "LM_LIST:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace mediapipe::api2::builder

View File

@ -533,3 +533,19 @@ cc_library(
"@com_google_absl//absl/status:statusor",
],
)
mediapipe_proto_library(
name = "body_rig_proto",
srcs = ["body_rig.proto"],
)
mediapipe_register_type(
base_name = "body_rig",
include_headers = ["mediapipe/framework/formats/body_rig.pb.h"],
types = [
"::mediapipe::Joint",
"::mediapipe::JointList",
"::std::vector<::mediapipe::JointList>",
],
deps = [":body_rig_cc_proto"],
)