Modify the TensorToImageFrameCalculator to support normalized outputs.
PiperOrigin-RevId: 540104988
parent b97d11fa76
commit 02d55dfb0a
@@ -1077,6 +1077,7 @@ cc_test(
     linkstatic = 1,
     deps = [
         ":tensor_to_image_frame_calculator",
+        ":tensor_to_image_frame_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/formats:image_frame",
@@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
 
  private:
   float scale_factor_;
+  bool scale_per_frame_min_max_;
 };
 
 REGISTER_CALCULATOR(TensorToImageFrameCalculator);
@@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
 absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
   scale_factor_ =
       cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
+  scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
+                                 .scale_per_frame_min_max();
   cc->SetOffset(TimestampDiff(0));
   return absl::OkStatus();
 }
@@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
   auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
   const int32_t total_size = height * width * depth;
 
+  if (scale_per_frame_min_max_) {
+    RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
+        << "Setting scale_per_frame_min_max requires FLOAT input tensors.";
+  }
   ::std::unique_ptr<const ImageFrame> output;
   if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
     // Allocate buffer with alignments.
     std::unique_ptr<uint8_t[]> buffer(
         new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
     auto data = input_tensor.flat<float>().data();
-    for (int i = 0; i < total_size; ++i) {
-      float d = scale_factor_ * data[i];
-      if (d < 0) d = 0;
-      if (d > 255) d = 255;
+    float min = 1e23;
+    float max = -1e23;
+    if (scale_per_frame_min_max_) {
+      for (int i = 0; i < total_size; ++i) {
+        float d = scale_factor_ * data[i];
+        if (d < min) {
+          min = d;
+        }
+        if (d > max) {
+          max = d;
+        }
+      }
+    }
+    for (int i = 0; i < total_size; ++i) {
+      float d = data[i];
+      if (scale_per_frame_min_max_) {
+        d = 255 * (d - min) / (max - min + 1e-9);
+      } else {
+        d = scale_factor_ * d;
+        if (d < 0) d = 0;
+        if (d > 255) d = 255;
+      }
      buffer[i] = d;
    }
    output = ::absl::make_unique<ImageFrame>(
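
For reference (not part of the commit), a minimal standalone C++ sketch of the arithmetic Process() applies when scale_per_frame_min_max is enabled: the frame is scanned once for its min and max (after scale_factor), then every value is remapped to [0, 255] with a small epsilon guarding against a constant frame. With one value at 255 and the rest at 200, as in the new test further down, the output pixels come out as 255 and 0.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const float scale_factor = 1.0f;
  std::vector<float> data(16 * 8, 200.0f);
  data[0] = 255.0f;  // Mirrors the new test: one bright pixel, rest uniform.

  // Per-frame min/max pass, matching the first loop added to Process().
  float min = 1e23f, max = -1e23f;
  for (float v : data) {
    const float d = scale_factor * v;
    min = std::min(min, d);
    max = std::max(max, d);
  }

  // Remap pass, matching the second loop added to Process().
  std::vector<uint8_t> pixels(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    // Same expression as the calculator: the intermediate is stored in a
    // float, so 254.999... rounds back to 255.0f before the uint8 store.
    const float d = 255 * (data[i] - min) / (max - min + 1e-9);
    pixels[i] = static_cast<uint8_t>(d);
  }
  std::printf("pixel[0]=%d pixel[1]=%d\n", pixels[0], pixels[1]);  // 255 and 0.
  return 0;
}
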
@@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
   // Multiples floating point tensor outputs by this value before converting to
   // uint8. This is useful for converting from range [0, 1] to [0, 255]
   optional float scale_factor = 1 [default = 1.0];
+
+  // If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
+  // per frame. This overrides any explicit scale_factor.
+  optional bool scale_per_frame_min_max = 2 [default = false];
 }
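
Usage note (not part of the commit): a graph author can turn the new behavior on through the calculator options. Below is a minimal C++ sketch of building such a node config, mirroring the pattern the updated test uses further down; the helper name MakeNormalizingNode and the stream names are illustrative only.

#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"

// Sketch: configure a TensorToImageFrameCalculator node so that FLOAT input
// tensors are rescaled to [0, 255] using each frame's own min/max instead of
// the fixed scale_factor.
mediapipe::CalculatorGraphConfig::Node MakeNormalizingNode() {
  mediapipe::CalculatorGraphConfig::Node config;
  config.set_calculator("TensorToImageFrameCalculator");
  config.add_input_stream("TENSOR:input_tensor");
  config.add_output_stream("IMAGE:output_image");
  config.mutable_options()
      ->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
      ->set_scale_per_frame_min_max(true);
  return config;
}
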
@@ -11,7 +11,9 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include <type_traits>
 
+#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
 template <class TypeParam>
 class TensorToImageFrameCalculatorTest : public ::testing::Test {
  protected:
-  void SetUpRunner() {
+  void SetUpRunner(bool scale_per_frame_min_max = false) {
     CalculatorGraphConfig::Node config;
     config.set_calculator("TensorToImageFrameCalculator");
     config.add_input_stream("TENSOR:input_tensor");
     config.add_output_stream("IMAGE:output_image");
+    config.mutable_options()
+        ->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
+        ->set_scale_per_frame_min_max(scale_per_frame_min_max);
     runner_ = absl::make_unique<CalculatorRunner>(config);
   }
 
@@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
   }
 }
 
+TYPED_TEST(TensorToImageFrameCalculatorTest,
+           Converts3DTensorToImageFrame2DGrayWithScaling) {
+  this->SetUpRunner(true);
+  auto& runner = this->runner_;
+  constexpr int kWidth = 16;
+  constexpr int kHeight = 8;
+  const tf::TensorShape tensor_shape{kHeight, kWidth};
+  auto tensor = absl::make_unique<tf::Tensor>(
+      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
+  auto tensor_vec = tensor->template flat<TypeParam>().data();
+
+  // Writing sequence of integers as floats which we want normalized.
+  tensor_vec[0] = 255;
+  for (int i = 1; i < kWidth * kHeight; ++i) {
+    tensor_vec[i] = 200;
+  }
+
+  const int64_t time = 1234;
+  runner->MutableInputs()->Tag(kTensor).packets.push_back(
+      Adopt(tensor.release()).At(Timestamp(time)));
+
+  if (!std::is_same<TypeParam, float>::value) {
+    EXPECT_FALSE(runner->Run().ok());
+    return;  // Short circuit because does not apply to other types.
+  } else {
+    EXPECT_TRUE(runner->Run().ok());
+    const std::vector<Packet>& output_packets =
+        runner->Outputs().Tag(kImage).packets;
+    EXPECT_EQ(1, output_packets.size());
+    EXPECT_EQ(time, output_packets[0].Timestamp().Value());
+    const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
+    EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
+    EXPECT_EQ(kWidth, output_image.Width());
+    EXPECT_EQ(kHeight, output_image.Height());
+
+    EXPECT_EQ(255, output_image.PixelData()[0]);
+    for (int i = 1; i < kWidth * kHeight; ++i) {
+      const uint8_t pixel_value = output_image.PixelData()[i];
+      ASSERT_EQ(0, pixel_value);
+    }
+  }
+}
+
 }  // namespace mediapipe