get_vector_item stream utility function.
PiperOrigin-RevId: 568998504
This commit is contained in:
parent
2ecccaf076
commit
8837b49026
|
@ -64,6 +64,37 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "get_vector_item",
|
||||||
|
hdrs = ["get_vector_item.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:get_vector_item_calculator",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "get_vector_item_test",
|
||||||
|
srcs = ["get_vector_item_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":get_vector_item",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status_matchers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "landmarks_to_detection",
|
name = "landmarks_to_detection",
|
||||||
srcs = ["landmarks_to_detection.cc"],
|
srcs = ["landmarks_to_detection.cc"],
|
||||||
|
|
66
mediapipe/framework/api2/stream/get_vector_item.h
Normal file
66
mediapipe/framework/api2/stream/get_vector_item.h
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_GET_VECTOR_ITEM_H_
|
||||||
|
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_GET_VECTOR_ITEM_H_
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
|
||||||
|
namespace internal_get_vector_item {
|
||||||
|
|
||||||
|
// Helper function that adds a node to a graph, that is capable of getting item
|
||||||
|
// from a vector of type (T).
|
||||||
|
template <class T>
|
||||||
|
mediapipe::api2::builder::GenericNode& AddGetVectorItemNode(
|
||||||
|
mediapipe::api2::builder::Graph& graph) {
|
||||||
|
if constexpr (std::is_same_v<T, mediapipe::NormalizedLandmarkList>) {
|
||||||
|
return graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator");
|
||||||
|
} else if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
|
||||||
|
return graph.AddNode("GetLandmarkListVectorItemCalculator");
|
||||||
|
} else if constexpr (std::is_same_v<T, mediapipe::ClassificationList>) {
|
||||||
|
return graph.AddNode("GetClassificationListVectorItemCalculator");
|
||||||
|
} else if constexpr (std::is_same_v<T, mediapipe::NormalizedRect>) {
|
||||||
|
return graph.AddNode("GetNormalizedRectVectorItemCalculator");
|
||||||
|
} else if constexpr (std::is_same_v<T, mediapipe::Rect>) {
|
||||||
|
return graph.AddNode("GetRectVectorItemCalculator");
|
||||||
|
} else {
|
||||||
|
static_assert(
|
||||||
|
dependent_false<T>::value,
|
||||||
|
"Get vector item node is not available for the specified type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace internal_get_vector_item
|
||||||
|
|
||||||
|
// Gets item from the vector.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// Graph graph;
|
||||||
|
//
|
||||||
|
// Stream<std::vector<LandmarkList>> multi_landmarks = ...;
|
||||||
|
// Stream<LandmarkList> landmarks =
|
||||||
|
// GetItem(multi_landmarks, 0, graph);
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
template <typename T>
|
||||||
|
Stream<T> GetItem(Stream<std::vector<T>> items, Stream<int> idx,
|
||||||
|
mediapipe::api2::builder::Graph& graph) {
|
||||||
|
auto& getter = internal_get_vector_item::AddGetVectorItemNode<T>(graph);
|
||||||
|
items.ConnectTo(getter.In("VECTOR"));
|
||||||
|
idx.ConnectTo(getter.In("INDEX"));
|
||||||
|
return getter.Out("ITEM").template Cast<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::api2::builder
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_GET_VECTOR_ITEM_H_
|
130
mediapipe/framework/api2/stream/get_vector_item_test.cc
Normal file
130
mediapipe/framework/api2/stream/get_vector_item_test.cc
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
#include "mediapipe/framework/api2/stream/get_vector_item.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
|
||||||
|
TEST(GetItem, GetNormalizedLandmarkListVectorItem) {
|
||||||
|
Graph graph;
|
||||||
|
Stream<std::vector<NormalizedLandmarkList>> items =
|
||||||
|
graph.In("ITEMS").Cast<std::vector<NormalizedLandmarkList>>();
|
||||||
|
Stream<int> idx = graph.In("IDX").Cast<int>();
|
||||||
|
Stream<NormalizedLandmarkList> item = GetItem(items, idx, graph);
|
||||||
|
item.SetName("item");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "GetNormalizedLandmarkListVectorItemCalculator"
|
||||||
|
input_stream: "INDEX:__stream_0"
|
||||||
|
input_stream: "VECTOR:__stream_1"
|
||||||
|
output_stream: "ITEM:item"
|
||||||
|
}
|
||||||
|
input_stream: "IDX:__stream_0"
|
||||||
|
input_stream: "ITEMS:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
CalculatorGraph calculator_graph;
|
||||||
|
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GetItem, GetLandmarkListVectorItem) {
|
||||||
|
Graph graph;
|
||||||
|
Stream<std::vector<LandmarkList>> items =
|
||||||
|
graph.In("ITEMS").Cast<std::vector<LandmarkList>>();
|
||||||
|
Stream<int> idx = graph.In("IDX").Cast<int>();
|
||||||
|
Stream<LandmarkList> item = GetItem(items, idx, graph);
|
||||||
|
item.SetName("item");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "GetLandmarkListVectorItemCalculator"
|
||||||
|
input_stream: "INDEX:__stream_0"
|
||||||
|
input_stream: "VECTOR:__stream_1"
|
||||||
|
output_stream: "ITEM:item"
|
||||||
|
}
|
||||||
|
input_stream: "IDX:__stream_0"
|
||||||
|
input_stream: "ITEMS:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
CalculatorGraph calculator_graph;
|
||||||
|
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GetItem, GetClassificationListVectorItem) {
|
||||||
|
Graph graph;
|
||||||
|
Stream<std::vector<ClassificationList>> items =
|
||||||
|
graph.In("ITEMS").Cast<std::vector<ClassificationList>>();
|
||||||
|
Stream<int> idx = graph.In("IDX").Cast<int>();
|
||||||
|
Stream<ClassificationList> item = GetItem(items, idx, graph);
|
||||||
|
item.SetName("item");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "GetClassificationListVectorItemCalculator"
|
||||||
|
input_stream: "INDEX:__stream_0"
|
||||||
|
input_stream: "VECTOR:__stream_1"
|
||||||
|
output_stream: "ITEM:item"
|
||||||
|
}
|
||||||
|
input_stream: "IDX:__stream_0"
|
||||||
|
input_stream: "ITEMS:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
CalculatorGraph calculator_graph;
|
||||||
|
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GetItem, GetNormalizedRectVectorItem) {
|
||||||
|
Graph graph;
|
||||||
|
Stream<std::vector<NormalizedRect>> items =
|
||||||
|
graph.In("ITEMS").Cast<std::vector<NormalizedRect>>();
|
||||||
|
Stream<int> idx = graph.In("IDX").Cast<int>();
|
||||||
|
Stream<NormalizedRect> item = GetItem(items, idx, graph);
|
||||||
|
item.SetName("item");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "GetNormalizedRectVectorItemCalculator"
|
||||||
|
input_stream: "INDEX:__stream_0"
|
||||||
|
input_stream: "VECTOR:__stream_1"
|
||||||
|
output_stream: "ITEM:item"
|
||||||
|
}
|
||||||
|
input_stream: "IDX:__stream_0"
|
||||||
|
input_stream: "ITEMS:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
CalculatorGraph calculator_graph;
|
||||||
|
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GetItem, GetRectVectorItem) {
|
||||||
|
Graph graph;
|
||||||
|
Stream<std::vector<Rect>> items = graph.In("ITEMS").Cast<std::vector<Rect>>();
|
||||||
|
Stream<int> idx = graph.In("IDX").Cast<int>();
|
||||||
|
Stream<Rect> item = GetItem(items, idx, graph);
|
||||||
|
item.SetName("item");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "GetRectVectorItemCalculator"
|
||||||
|
input_stream: "INDEX:__stream_0"
|
||||||
|
input_stream: "VECTOR:__stream_1"
|
||||||
|
output_stream: "ITEM:item"
|
||||||
|
}
|
||||||
|
input_stream: "IDX:__stream_0"
|
||||||
|
input_stream: "ITEMS:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
CalculatorGraph calculator_graph;
|
||||||
|
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user