diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc index 2f4ff28cf..f92ddf08d 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc @@ -37,8 +37,10 @@ class TensorToVectorIntCalculator : public CalculatorBase { private: void TokenizeVector(std::vector* vector) const; + void RemoveOverlapVector(std::vector* vector); TensorToVectorIntCalculatorOptions options_; + int32_t overlapping_values_; }; REGISTER_CALCULATOR(TensorToVectorIntCalculator); @@ -66,6 +68,7 @@ absl::Status TensorToVectorIntCalculator::GetContract(CalculatorContract* cc) { absl::Status TensorToVectorIntCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); + overlapping_values_ = 0; // Nothing is trimmed from the first vector. // Inform mediapipe that this calculator produces an output at time t for // each input received at time t (i.e. this calculator does not buffer @@ -106,6 +109,7 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(&instance_output); + RemoveOverlapVector(&instance_output); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } else { @@ -128,12 +132,28 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(output.get()); + RemoveOverlapVector(output.get()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } return absl::OkStatus(); } +void TensorToVectorIntCalculator::RemoveOverlapVector( + std::vector* vector) { // Drops the leading overlap from every vector after the first. + if (options_.overlap() <= 0) { // Overlap removal is disabled. + return; + } + if (overlapping_values_ > 0) { // Still 0 for the first vector after Open(). + if (vector->size() < overlapping_values_) { // Too short: nothing survives. + vector->clear(); + } else { // Cut the leading overlapping_values_ entries. + vector->erase(vector->begin(), vector->begin() + overlapping_values_); + } + } + overlapping_values_ = options_.overlap(); // Trim every later vector by this much. +} + void TensorToVectorIntCalculator::TokenizeVector( std::vector* vector) const { if 
(!options_.tensor_is_token()) { diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto index 9da3298b9..76b9be952 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto @@ -36,4 +36,8 @@ message TensorToVectorIntCalculatorOptions { optional bool tensor_is_token = 3 [default = false]; // Threshold for the token generation. optional float token_threshold = 4 [default = 0.5]; + + // Number of leading values that overlap with the preceding vector. They are + // removed from every vector but the first to reduce redundancy (<= 0: off). + optional int32 overlap = 5 [default = 0]; } diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc index 60c0d47ec..406c2c1a7 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc @@ -28,7 +28,8 @@ namespace tf = ::tensorflow; class TensorToVectorIntCalculatorTest : public ::testing::Test { protected: void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd, - const bool tensor_is_token = false) { + const bool tensor_is_token = false, + const int32_t overlap = 0) { CalculatorGraphConfig::Node config; config.set_calculator("TensorToVectorIntCalculator"); config.add_input_stream("input_tensor"); @@ -38,6 +39,7 @@ class TensorToVectorIntCalculatorTest : public ::testing::Test { options->set_tensor_is_2d(tensor_is_2d); options->set_flatten_nd(flatten_nd); options->set_tensor_is_token(tensor_is_token); + options->set_overlap(overlap); runner_ = absl::make_unique(config); } @@ -188,5 +190,54 @@ TEST_F(TensorToVectorIntCalculatorTest, FlattenShouldTakeAllDimensions) { } } +TEST_F(TensorToVectorIntCalculatorTest, Overlap) { + 
SetUpRunner(false, false, false, 2); + for (int time = 0; time < 3; ++time) { + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_INT64, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // Element i is time + 2^i, so the vectors differ from packet to + // packet. + tensor_vec(i) = static_cast(time + (1 << i)); + } + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + } + + ASSERT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(3, output_packets.size()); + + { + // The first vector is emitted in full. + int time = 0; + EXPECT_EQ(time, output_packets[time].Timestamp().Value()); + const std::vector& output_vector = + output_packets[time].Get>(); + + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const int64 expected = static_cast(time + (1 << i)); + EXPECT_EQ(expected, output_vector[i]); + } + } + + // All following vectors have their first 2 (overlap) values removed. + for (int time = 1; time < 3; ++time) { + EXPECT_EQ(time, output_packets[time].Timestamp().Value()); + const std::vector& output_vector = + output_packets[time].Get>(); + + EXPECT_EQ(3, output_vector.size()); + for (int i = 0; i < 3; ++i) { + const int64 expected = static_cast(time + (1 << (i + 2))); + EXPECT_EQ(expected, output_vector[i]); + } + } +} + } // namespace } // namespace mediapipe