#MediaPipe Add ConcatenateStringVectorCalculator.

PiperOrigin-RevId: 532956844
This commit is contained in:
Rachel Hornung 2023-05-17 17:08:38 -07:00 committed by Copybara-Service
parent 02230f65d1
commit 25458138a9
2 changed files with 26 additions and 3 deletions

View File

@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
typedef ConcatenateVectorCalculator<std::string>
ConcatenateStringVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator);
// Example config:
// node {
// calculator: "ConcatenateTfLiteTensorVectorCalculator"

View File

@ -30,13 +30,15 @@ namespace mediapipe {
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp,
template <typename T>
void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()->Index(index).packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
MakePacket<std::vector<T>>(input).At(Timestamp(timestamp)));
}
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
template <typename T>
void AddInputVectors(const std::vector<std::vector<T>>& inputs,
int64_t timestamp, CalculatorRunner* runner) {
for (int i = 0; i < inputs.size(); ++i) {
AddInputVector(i, inputs[i], timestamp, runner);
@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size());
}
TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) {
CalculatorRunner runner("ConcatenateStringVectorCalculator",
/*options_string=*/"", /*num_inputs=*/3,
/*num_outputs=*/1, /*num_side_packets=*/0);
std::vector<std::vector<std::string>> inputs = {
{"a", "b"}, {"c"}, {"d", "e", "f"}};
AddInputVectors(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
}
typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
TestConcatenateUniqueIntPtrCalculator;
MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);