Option to remove overlapping values computed for different timestamps.
PiperOrigin-RevId: 499635143
This commit is contained in:
parent
24cc0672c4
commit
43bf02443c
|
@ -37,8 +37,10 @@ class TensorToVectorIntCalculator : public CalculatorBase {
|
|||
|
||||
private:
|
||||
void TokenizeVector(std::vector<int64>* vector) const;
|
||||
void RemoveOverlapVector(std::vector<int64>* vector);
|
||||
|
||||
TensorToVectorIntCalculatorOptions options_;
|
||||
int32_t overlapping_values_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorToVectorIntCalculator);
|
||||
|
||||
|
@ -66,6 +68,7 @@ absl::Status TensorToVectorIntCalculator::GetContract(CalculatorContract* cc) {
|
|||
|
||||
absl::Status TensorToVectorIntCalculator::Open(CalculatorContext* cc) {
|
||||
options_ = cc->Options<TensorToVectorIntCalculatorOptions>();
|
||||
overlapping_values_ = 0;
|
||||
|
||||
// Inform mediapipe that this calculator produces an output at time t for
|
||||
// each input received at time t (i.e. this calculator does not buffer
|
||||
|
@ -106,6 +109,7 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
TokenizeVector(&instance_output);
|
||||
RemoveOverlapVector(&instance_output);
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
} else {
|
||||
|
@ -128,12 +132,28 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
TokenizeVector(output.get());
|
||||
RemoveOverlapVector(output.get());
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void TensorToVectorIntCalculator::RemoveOverlapVector(
|
||||
std::vector<int64>* vector) {
|
||||
if (options_.overlap() <= 0) {
|
||||
return;
|
||||
}
|
||||
if (overlapping_values_ > 0) {
|
||||
if (vector->size() < overlapping_values_) {
|
||||
vector->clear();
|
||||
} else {
|
||||
vector->erase(vector->begin(), vector->begin() + overlapping_values_);
|
||||
}
|
||||
}
|
||||
overlapping_values_ = options_.overlap();
|
||||
}
|
||||
|
||||
void TensorToVectorIntCalculator::TokenizeVector(
|
||||
std::vector<int64>* vector) const {
|
||||
if (!options_.tensor_is_token()) {
|
||||
|
|
|
@ -36,4 +36,8 @@ message TensorToVectorIntCalculatorOptions {
|
|||
optional bool tensor_is_token = 3 [default = false];
|
||||
// Threshold for the token generation.
|
||||
optional float token_threshold = 4 [default = 0.5];
|
||||
|
||||
// Values which overlap between timely following vectors. They are removed
|
||||
// from the output to reduce redundancy.
|
||||
optional int32 overlap = 5 [default = 0];
|
||||
}
|
||||
|
|
|
@ -28,7 +28,8 @@ namespace tf = ::tensorflow;
|
|||
class TensorToVectorIntCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd,
|
||||
const bool tensor_is_token = false) {
|
||||
const bool tensor_is_token = false,
|
||||
const int32_t overlap = 0) {
|
||||
CalculatorGraphConfig::Node config;
|
||||
config.set_calculator("TensorToVectorIntCalculator");
|
||||
config.add_input_stream("input_tensor");
|
||||
|
@ -38,6 +39,7 @@ class TensorToVectorIntCalculatorTest : public ::testing::Test {
|
|||
options->set_tensor_is_2d(tensor_is_2d);
|
||||
options->set_flatten_nd(flatten_nd);
|
||||
options->set_tensor_is_token(tensor_is_token);
|
||||
options->set_overlap(overlap);
|
||||
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||
}
|
||||
|
||||
|
@ -188,5 +190,54 @@ TEST_F(TensorToVectorIntCalculatorTest, FlattenShouldTakeAllDimensions) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorToVectorIntCalculatorTest, Overlap) {
|
||||
SetUpRunner(false, false, false, 2);
|
||||
for (int time = 0; time < 3; ++time) {
|
||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_INT64, tensor_shape);
|
||||
auto tensor_vec = tensor->vec<int64>();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
// 2^i can be represented exactly in floating point numbers if 'i' is
|
||||
// small.
|
||||
tensor_vec(i) = static_cast<int64>(time + (1 << i));
|
||||
}
|
||||
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
}
|
||||
|
||||
ASSERT_TRUE(runner_->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Index(0).packets;
|
||||
EXPECT_EQ(3, output_packets.size());
|
||||
|
||||
{
|
||||
// First vector in full.
|
||||
int time = 0;
|
||||
EXPECT_EQ(time, output_packets[time].Timestamp().Value());
|
||||
const std::vector<int64>& output_vector =
|
||||
output_packets[time].Get<std::vector<int64>>();
|
||||
|
||||
EXPECT_EQ(5, output_vector.size());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
const int64 expected = static_cast<int64>(time + (1 << i));
|
||||
EXPECT_EQ(expected, output_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// All following vectors the overlap removed
|
||||
for (int time = 1; time < 3; ++time) {
|
||||
EXPECT_EQ(time, output_packets[time].Timestamp().Value());
|
||||
const std::vector<int64>& output_vector =
|
||||
output_packets[time].Get<std::vector<int64>>();
|
||||
|
||||
EXPECT_EQ(3, output_vector.size());
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
const int64 expected = static_cast<int64>(time + (1 << (i + 2)));
|
||||
EXPECT_EQ(expected, output_vector[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user