Make TensorToVectorFloatCalculator compatible with unaligned tensors.

No performance impact is expected, since the unaligned Eigen::TensorMap is used only to populate a std::vector<float>.

PiperOrigin-RevId: 505251810
This commit is contained in:
MediaPipe Team 2023-01-27 18:06:09 -08:00 committed by Copybara-Service
parent 702cc0c42c
commit ee2f940e1f
3 changed files with 26 additions and 1 deletions

View File

@ -1054,6 +1054,7 @@ cc_test(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/util:packet_test_util",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],

View File

@ -102,7 +102,7 @@ absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) {
}
auto output =
absl::make_unique<std::vector<float>>(input_tensor.NumElements());
const auto& tensor_values = input_tensor.flat<float>();
const auto& tensor_values = input_tensor.unaligned_flat<float>();
for (int i = 0; i < input_tensor.NumElements(); ++i) {
output->at(i) = tensor_values(i);
}

View File

@ -16,6 +16,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/packet_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
@ -129,5 +130,28 @@ TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) {
}
}
TEST_F(TensorToVectorFloatCalculatorTest, AcceptsUnalignedTensors) {
SetUpRunner(/*tensor_is_2d=*/false, /*flatten_nd=*/false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 5});
tf::Tensor tensor(tf::DT_FLOAT, tensor_shape);
auto slice = tensor.Slice(1, 1).flat<float>();
for (int i = 0; i < 5; ++i) {
slice(i) = i;
}
auto input_tensor = tensor.SubSlice(1);
// Ensure that the input tensor is unaligned.
ASSERT_FALSE(input_tensor.IsAligned());
runner_->MutableInputs()->Index(0).packets.push_back(
MakePacket<tf::Tensor>(input_tensor).At(Timestamp(5)));
ASSERT_TRUE(runner_->Run().ok());
EXPECT_THAT(runner_->Outputs().Index(0).packets,
ElementsAre(PacketContainsTimestampAndPayload<std::vector<float>>(
Timestamp(5), std::vector<float>({0, 1, 2, 3, 4}))));
}
} // namespace
} // namespace mediapipe