split stream utility function.
PiperOrigin-RevId: 566722901
This commit is contained in:
parent
58bb2d1b92
commit
bbf40cba87
|
@ -944,6 +944,7 @@ cc_library(
|
|||
deps = [
|
||||
":split_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -196,6 +197,18 @@ class SplitLandmarkListCalculator
|
|||
};
|
||||
REGISTER_CALCULATOR(SplitLandmarkListCalculator);
|
||||
|
||||
class SplitJointListCalculator : public SplitListsCalculator<Joint, JointList> {
|
||||
protected:
|
||||
int ListSize(const JointList& list) const override {
|
||||
return list.joint_size();
|
||||
}
|
||||
const Joint GetItem(const JointList& list, int idx) const override {
|
||||
return list.joint(idx);
|
||||
}
|
||||
Joint* AddItem(JointList& list) const override { return list.add_joint(); }
|
||||
};
|
||||
REGISTER_CALCULATOR(SplitJointListCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
|
|
|
@ -192,3 +192,40 @@ cc_test(
|
|||
"//mediapipe/framework/port:status_matchers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "split",
|
||||
hdrs = ["split.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:split_proto_list_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "split_test",
|
||||
srcs = ["split_test.cc"],
|
||||
deps = [
|
||||
":split",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status_matchers",
|
||||
],
|
||||
)
|
||||
|
|
335
mediapipe/framework/api2/stream/split.h
Normal file
335
mediapipe/framework/api2/stream/split.h
Normal file
|
@ -0,0 +1,335 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_
|
||||
|
||||
#include <initializer_list>
|
||||
#include <iterator>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
namespace stream_split_internal {
|
||||
|
||||
// Helper function that adds a node to a graph, that is capable of splitting a
|
||||
// specific type (T).
|
||||
template <class T>
|
||||
mediapipe::api2::builder::GenericNode& AddSplitVectorNode(
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
if constexpr (std::is_same_v<T, std::vector<TfLiteTensor>>) {
|
||||
return graph.AddNode("SplitTfLiteTensorVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<T, std::vector<mediapipe::Tensor>>) {
|
||||
return graph.AddNode("SplitTensorVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<T, std::vector<uint64_t>>) {
|
||||
return graph.AddNode("SplitUint64tVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<
|
||||
T, std::vector<mediapipe::NormalizedLandmark>>) {
|
||||
return graph.AddNode("SplitLandmarkVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<
|
||||
T, std::vector<mediapipe::NormalizedLandmarkList>>) {
|
||||
return graph.AddNode("SplitNormalizedLandmarkListVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<T,
|
||||
std::vector<mediapipe::NormalizedRect>>) {
|
||||
return graph.AddNode("SplitNormalizedRectVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<T, std::vector<Matrix>>) {
|
||||
return graph.AddNode("SplitMatrixVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<T, std::vector<mediapipe::Detection>>) {
|
||||
return graph.AddNode("SplitDetectionVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<
|
||||
T, std::vector<mediapipe::ClassificationList>>) {
|
||||
return graph.AddNode("SplitClassificationListVectorCalculator");
|
||||
} else if constexpr (std::is_same_v<T, mediapipe::NormalizedLandmarkList>) {
|
||||
return graph.AddNode("SplitNormalizedLandmarkListCalculator");
|
||||
} else if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
|
||||
return graph.AddNode("SplitLandmarkListCalculator");
|
||||
} else if constexpr (std::is_same_v<T, mediapipe::JointList>) {
|
||||
return graph.AddNode("SplitJointListCalculator");
|
||||
} else {
|
||||
static_assert(dependent_false<T>::value,
|
||||
"Split node is not available for the specified type.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool kIteratorContainsRanges = false>
|
||||
struct split_result_item {
|
||||
using type = typename T::value_type;
|
||||
};
|
||||
template <>
|
||||
struct split_result_item<mediapipe::NormalizedLandmarkList,
|
||||
/*kIteratorContainsRanges=*/false> {
|
||||
using type = mediapipe::NormalizedLandmark;
|
||||
};
|
||||
template <>
|
||||
struct split_result_item<mediapipe::LandmarkList,
|
||||
/*kIteratorContainsRanges=*/false> {
|
||||
using type = mediapipe::Landmark;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct split_result_item<T, /*kIteratorContainsRanges=*/true> {
|
||||
using type = std::vector<typename T::value_type>;
|
||||
};
|
||||
template <>
|
||||
struct split_result_item<mediapipe::NormalizedLandmarkList,
|
||||
/*kIteratorContainsRanges=*/true> {
|
||||
using type = mediapipe::NormalizedLandmarkList;
|
||||
};
|
||||
template <>
|
||||
struct split_result_item<mediapipe::LandmarkList,
|
||||
/*kIteratorContainsRanges=*/true> {
|
||||
using type = mediapipe::LandmarkList;
|
||||
};
|
||||
|
||||
template <typename CollectionT, typename I>
|
||||
auto Split(Stream<CollectionT> items, I begin, I end,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
auto& splitter = AddSplitVectorNode<CollectionT>(graph);
|
||||
items.ConnectTo(splitter.In(""));
|
||||
|
||||
constexpr bool kIteratorContainsRanges =
|
||||
std::is_same_v<typename std::iterator_traits<I>::value_type,
|
||||
std::pair<int, int>>;
|
||||
using R =
|
||||
typename split_result_item<CollectionT, kIteratorContainsRanges>::type;
|
||||
auto& splitter_opts =
|
||||
splitter.template GetOptions<mediapipe::SplitVectorCalculatorOptions>();
|
||||
if constexpr (!kIteratorContainsRanges) {
|
||||
splitter_opts.set_element_only(true);
|
||||
}
|
||||
std::vector<Stream<R>> result;
|
||||
int output = 0;
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
auto* range = splitter_opts.add_ranges();
|
||||
if constexpr (kIteratorContainsRanges) {
|
||||
range->set_begin(it->first);
|
||||
range->set_end(it->second);
|
||||
} else {
|
||||
range->set_begin(*it);
|
||||
range->set_end(*it + 1);
|
||||
}
|
||||
result.push_back(splitter.Out("")[output++].template Cast<R>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename CollectionT, typename I>
|
||||
Stream<CollectionT> SplitAndCombine(Stream<CollectionT> items, I begin, I end,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
auto& splitter = AddSplitVectorNode<CollectionT>(graph);
|
||||
items.ConnectTo(splitter.In(""));
|
||||
|
||||
constexpr bool kIteratorContainsRanges =
|
||||
std::is_same_v<typename std::iterator_traits<I>::value_type,
|
||||
std::pair<int, int>>;
|
||||
|
||||
auto& splitter_opts =
|
||||
splitter.template GetOptions<mediapipe::SplitVectorCalculatorOptions>();
|
||||
splitter_opts.set_combine_outputs(true);
|
||||
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
auto* range = splitter_opts.add_ranges();
|
||||
if constexpr (kIteratorContainsRanges) {
|
||||
range->set_begin(it->first);
|
||||
range->set_end(it->second);
|
||||
} else {
|
||||
range->set_begin(*it);
|
||||
range->set_end(*it + 1);
|
||||
}
|
||||
}
|
||||
return splitter.Out("").template Cast<CollectionT>();
|
||||
}
|
||||
|
||||
} // namespace stream_split_internal
|
||||
|
||||
// Splits stream containing a collection based on passed @indices into a vector
|
||||
// of streams where each steam repesents individual item of a collection.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
// std::vector<int> indices = {0, 1, 2, 3};
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// std::vector<Stream<Detection>> detections_split =
|
||||
// Split(detections, indices, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// std::vector<Stream<NormalizedLandmark>> landmarks_split =
|
||||
// Split(landmarks, indices, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT, typename I>
|
||||
auto Split(Stream<CollectionT> items, const I& indices,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::Split(items, indices.begin(), indices.end(),
|
||||
graph);
|
||||
}
|
||||
// Splits stream containing a collection based on passed @indices into a vector
|
||||
// of streams where each steam repesents individual item of a collection.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
// std::vector<int> indices = {0, 1, 2, 3};
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// std::vector<Stream<Detection>> detections_split =
|
||||
// Split(detections, indices, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// std::vector<Stream<NormalizedLandmark>> landmarks_split =
|
||||
// Split(landmarks, indices, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT>
|
||||
auto Split(Stream<CollectionT> items, std::initializer_list<int> indices,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::Split(items, indices.begin(), indices.end(),
|
||||
graph);
|
||||
}
|
||||
|
||||
// Splits stream containing a collection into a sub ranges, each represented as
|
||||
// a stream containing same collection type.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
// std::vector<std::pair<int, int>> ranges = {{0, 3}, {7, 10}};
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// std::vector<Stream<std::vector<Detection>>> detections_split =
|
||||
// SplitToRanges(detections, ranges, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// std::vector<Stream<NormalizedLandmarkList>> landmarks_split =
|
||||
// SplitToRanges(landmarks, ranges, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT, typename RangeT>
|
||||
auto SplitToRanges(Stream<CollectionT> items, const RangeT& ranges,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::Split(items, ranges.begin(), ranges.end(),
|
||||
graph);
|
||||
}
|
||||
|
||||
// Splits stream containing a collection into a sub ranges, each represented as
|
||||
// a stream containing same collection type.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
// std::vector<std::pair<int, int>> ranges = {{0, 3}, {7, 10}};
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// std::vector<Stream<std::vector<Detection>>> detections_split =
|
||||
// SplitToRanges(detections, ranges, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// std::vector<Stream<NormalizedLandmarkList>> landmarks_split =
|
||||
// SplitToRanges(landmarks, ranges, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT>
|
||||
auto SplitToRanges(Stream<CollectionT> items,
|
||||
std::initializer_list<std::pair<int, int>> ranges,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::Split(items, ranges.begin(), ranges.end(),
|
||||
graph);
|
||||
}
|
||||
|
||||
// Splits stream containing a collection into a sub ranges and combines them
|
||||
// into a stream containing same collection type.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
// std::vector<std::pair<int, int>> ranges = {{0, 3}, {7, 10}};
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// Stream<std::vector<Detection>> detections_split_and_combined =
|
||||
// SplitAndCombine(detections, ranges, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// Stream<NormalizedLandmarkList> landmarks_split_and_combined =
|
||||
// SplitAndCombine(landmarks, ranges, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT, typename RangeT>
|
||||
Stream<CollectionT> SplitAndCombine(Stream<CollectionT> items,
|
||||
const RangeT& ranges,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::SplitAndCombine(items, ranges.begin(),
|
||||
ranges.end(), graph);
|
||||
}
|
||||
|
||||
// Splits stream containing a collection into a sub ranges and combines them
|
||||
// into a stream containing same collection type.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// Stream<std::vector<Detection>> detections_split_and_combined =
|
||||
// SplitAndCombine(detections, {{0, 3}, {7, 10}}, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// Stream<NormalizedLandmarkList> landmarks_split_and_combined =
|
||||
// SplitAndCombine(landmarks, {{0, 3}, {7, 10}}, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT>
|
||||
Stream<CollectionT> SplitAndCombine(
|
||||
Stream<CollectionT> items,
|
||||
std::initializer_list<std::pair<int, int>> ranges,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::SplitAndCombine(items, ranges.begin(),
|
||||
ranges.end(), graph);
|
||||
}
|
||||
|
||||
// Splits stream containing a collection into individual items and combines them
|
||||
// into a stream containing same collection type.
|
||||
//
|
||||
// Example:
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
//
|
||||
// Stream<std::vector<Detection>> detections = ...;
|
||||
// Stream<std::vector<Detection>> detections_split_and_combined =
|
||||
// SplitAndCombine(detections, {0, 7, 10}, graph);
|
||||
//
|
||||
// Stream<NormalizedLandmarkList> landmarks = ...;
|
||||
// Stream<NormalizedLandmarkList> landmarks_split_and_combined =
|
||||
// SplitAndCombine(landmarks, {0, 7, 10}, graph);
|
||||
//
|
||||
// ```
|
||||
template <typename CollectionT>
|
||||
Stream<CollectionT> SplitAndCombine(Stream<CollectionT> items,
|
||||
std::initializer_list<int> ranges,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
return stream_split_internal::SplitAndCombine(items, ranges.begin(),
|
||||
ranges.end(), graph);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SPLIT_H_
|
473
mediapipe/framework/api2/stream/split_test.cc
Normal file
473
mediapipe/framework/api2/stream/split_test.cc
Normal file
|
@ -0,0 +1,473 @@
|
|||
#include "mediapipe/framework/api2/stream/split.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
namespace {
|
||||
|
||||
TEST(SplitTest, SplitToRanges2Ranges) {
|
||||
Graph graph;
|
||||
Stream<std::vector<mediapipe::Tensor>> tensors =
|
||||
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
|
||||
std::vector<Stream<std::vector<mediapipe::Tensor>>> result =
|
||||
SplitToRanges(tensors, {{0, 1}, {1, 2}}, graph);
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitTensorVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 1 end: 2 }
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "TENSORS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, Split2Items) {
|
||||
Graph graph;
|
||||
Stream<std::vector<mediapipe::Tensor>> tensors =
|
||||
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
|
||||
std::vector<Stream<mediapipe::Tensor>> result = Split(tensors, {0, 1}, graph);
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitTensorVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 1 end: 2 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "TENSORS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, Split2Uint64tItems) {
|
||||
Graph graph;
|
||||
Stream<std::vector<uint64_t>> ids =
|
||||
graph.In("IDS").Cast<std::vector<uint64_t>>();
|
||||
std::vector<Stream<uint64_t>> result = Split(ids, {0, 1}, graph);
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitUint64tVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 1 end: 2 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "IDS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitToRanges5Ranges) {
|
||||
Graph graph;
|
||||
Stream<std::vector<NormalizedRect>> tensors =
|
||||
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
|
||||
std::vector<Stream<std::vector<NormalizedRect>>> result =
|
||||
SplitToRanges(tensors, {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}}, graph);
|
||||
EXPECT_EQ(result.size(), 5);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitNormalizedRectVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
output_stream: "__stream_3"
|
||||
output_stream: "__stream_4"
|
||||
output_stream: "__stream_5"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 1 end: 2 }
|
||||
ranges { begin: 2 end: 3 }
|
||||
ranges { begin: 3 end: 4 }
|
||||
ranges { begin: 4 end: 5 }
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "RECTS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, Split5Items) {
|
||||
Graph graph;
|
||||
Stream<std::vector<NormalizedRect>> tensors =
|
||||
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
|
||||
std::vector<Stream<NormalizedRect>> result =
|
||||
Split(tensors, {0, 1, 2, 3, 4}, graph);
|
||||
EXPECT_EQ(result.size(), 5);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitNormalizedRectVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
output_stream: "__stream_3"
|
||||
output_stream: "__stream_4"
|
||||
output_stream: "__stream_5"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 1 end: 2 }
|
||||
ranges { begin: 2 end: 3 }
|
||||
ranges { begin: 3 end: 4 }
|
||||
ranges { begin: 4 end: 5 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "RECTS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitPassingVectorIndices) {
|
||||
Graph graph;
|
||||
Stream<std::vector<NormalizedRect>> tensors =
|
||||
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
|
||||
std::vector<int> indices = {250, 300};
|
||||
std::vector<Stream<NormalizedRect>> second_split_result =
|
||||
Split(tensors, indices, graph);
|
||||
EXPECT_EQ(second_split_result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitNormalizedRectVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 251 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "RECTS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitToRangesPassingVectorRanges) {
|
||||
Graph graph;
|
||||
Stream<std::vector<NormalizedRect>> tensors =
|
||||
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
|
||||
std::vector<std::pair<int, int>> indices = {{250, 255}, {300, 301}};
|
||||
std::vector<Stream<std::vector<NormalizedRect>>> second_split_result =
|
||||
SplitToRanges(tensors, indices, graph);
|
||||
EXPECT_EQ(second_split_result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitNormalizedRectVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 255 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "RECTS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitToRangesNormalizedLandmarkList) {
|
||||
Graph graph;
|
||||
Stream<NormalizedLandmarkList> tensors =
|
||||
graph.In("LM_LIST").Cast<NormalizedLandmarkList>();
|
||||
std::vector<std::pair<int, int>> indices = {{250, 255}, {300, 301}};
|
||||
std::vector<Stream<NormalizedLandmarkList>> second_split_result =
|
||||
SplitToRanges(tensors, indices, graph);
|
||||
EXPECT_EQ(second_split_result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 255 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LM_LIST:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitNormalizedLandmarkList) {
|
||||
Graph graph;
|
||||
Stream<NormalizedLandmarkList> tensors =
|
||||
graph.In("LM_LIST").Cast<NormalizedLandmarkList>();
|
||||
std::vector<Stream<NormalizedLandmark>> second_split_result =
|
||||
Split(tensors, {250, 300}, graph);
|
||||
EXPECT_EQ(second_split_result.size(), 2);
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "__stream_1"
|
||||
output_stream: "__stream_2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 251 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LM_LIST:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitAndCombineRanges) {
|
||||
Graph graph;
|
||||
Stream<std::vector<mediapipe::Tensor>> tensors =
|
||||
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
|
||||
Stream<std::vector<mediapipe::Tensor>> result =
|
||||
SplitAndCombine(tensors, {{0, 1}, {2, 5}, {70, 75}}, graph);
|
||||
result.SetName("tensors_split_and_combined");
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitTensorVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "tensors_split_and_combined"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 2 end: 5 }
|
||||
ranges { begin: 70 end: 75 }
|
||||
combine_outputs: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "TENSORS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitAndCombineIndividualIndices) {
|
||||
Graph graph;
|
||||
Stream<std::vector<mediapipe::Tensor>> tensors =
|
||||
graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
|
||||
Stream<std::vector<mediapipe::Tensor>> result =
|
||||
SplitAndCombine(tensors, {0, 2, 70, 100}, graph);
|
||||
result.SetName("tensors_split_and_combined");
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitTensorVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "tensors_split_and_combined"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 0 end: 1 }
|
||||
ranges { begin: 2 end: 3 }
|
||||
ranges { begin: 70 end: 71 }
|
||||
ranges { begin: 100 end: 101 }
|
||||
combine_outputs: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "TENSORS:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitAndCombineLandmarkList) {
|
||||
Graph graph;
|
||||
Stream<LandmarkList> tensors = graph.In("LM_LIST").Cast<LandmarkList>();
|
||||
std::vector<std::pair<int, int>> ranges = {{250, 255}, {300, 301}};
|
||||
Stream<LandmarkList> landmark_list = SplitAndCombine(tensors, ranges, graph);
|
||||
landmark_list.SetName("landmark_list");
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitLandmarkListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "landmark_list"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 255 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
combine_outputs: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LM_LIST:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitAndCombineLandmarkListIndividualIndices) {
|
||||
Graph graph;
|
||||
Stream<LandmarkList> tensors = graph.In("LM_LIST").Cast<LandmarkList>();
|
||||
std::vector<int> indices = {250, 300};
|
||||
Stream<LandmarkList> landmark_list = SplitAndCombine(tensors, indices, graph);
|
||||
landmark_list.SetName("landmark_list");
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitLandmarkListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "landmark_list"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 251 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
combine_outputs: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LM_LIST:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitAndCombineJointList) {
|
||||
Graph graph;
|
||||
Stream<JointList> tensors = graph.In("JT_LIST").Cast<JointList>();
|
||||
std::vector<std::pair<int, int>> ranges = {{250, 255}, {300, 301}};
|
||||
Stream<JointList> joint_list = SplitAndCombine(tensors, ranges, graph);
|
||||
joint_list.SetName("joint_list");
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitJointListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "joint_list"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 255 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
combine_outputs: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "JT_LIST:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(SplitTest, SplitAndCombineJointListIndividualIndices) {
|
||||
Graph graph;
|
||||
Stream<JointList> tensors = graph.In("LM_LIST").Cast<JointList>();
|
||||
std::vector<int> indices = {250, 300};
|
||||
Stream<JointList> joint_list = SplitAndCombine(tensors, indices, graph);
|
||||
joint_list.SetName("joint_list");
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "SplitJointListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "joint_list"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges { begin: 250 end: 251 }
|
||||
ranges { begin: 300 end: 301 }
|
||||
combine_outputs: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LM_LIST:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::api2::builder
|
|
@ -533,3 +533,19 @@ cc_library(
|
|||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "body_rig_proto",
|
||||
srcs = ["body_rig.proto"],
|
||||
)
|
||||
|
||||
mediapipe_register_type(
|
||||
base_name = "body_rig",
|
||||
include_headers = ["mediapipe/framework/formats/body_rig.pb.h"],
|
||||
types = [
|
||||
"::mediapipe::Joint",
|
||||
"::mediapipe::JointList",
|
||||
"::std::vector<::mediapipe::JointList>",
|
||||
],
|
||||
deps = [":body_rig_cc_proto"],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user