Make TensorToVectorFloatCalculator compatible with unaligned tensors.
No performance impact is expected, since the unaligned Eigen::TensorMap is used only to populate a std::vector<float>. PiperOrigin-RevId: 505251810
This commit is contained in:
parent
702cc0c42c
commit
ee2f940e1f
|
@ -1054,6 +1054,7 @@ cc_test(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/util:packet_test_util",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
|
|
|
@ -102,7 +102,7 @@ absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
auto output =
|
||||
absl::make_unique<std::vector<float>>(input_tensor.NumElements());
|
||||
const auto& tensor_values = input_tensor.flat<float>();
|
||||
const auto& tensor_values = input_tensor.unaligned_flat<float>();
|
||||
for (int i = 0; i < input_tensor.NumElements(); ++i) {
|
||||
output->at(i) = tensor_values(i);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/util/packet_test_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
|
@ -129,5 +130,28 @@ TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorToVectorFloatCalculatorTest, AcceptsUnalignedTensors) {
|
||||
SetUpRunner(/*tensor_is_2d=*/false, /*flatten_nd=*/false);
|
||||
|
||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 5});
|
||||
tf::Tensor tensor(tf::DT_FLOAT, tensor_shape);
|
||||
auto slice = tensor.Slice(1, 1).flat<float>();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
slice(i) = i;
|
||||
}
|
||||
|
||||
auto input_tensor = tensor.SubSlice(1);
|
||||
// Ensure that the input tensor is unaligned.
|
||||
ASSERT_FALSE(input_tensor.IsAligned());
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
MakePacket<tf::Tensor>(input_tensor).At(Timestamp(5)));
|
||||
|
||||
ASSERT_TRUE(runner_->Run().ok());
|
||||
|
||||
EXPECT_THAT(runner_->Outputs().Index(0).packets,
|
||||
ElementsAre(PacketContainsTimestampAndPayload<std::vector<float>>(
|
||||
Timestamp(5), std::vector<float>({0, 1, 2, 3, 4}))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user