Modify the TensorToImageFrameCalculator to support normalized outputs.
PiperOrigin-RevId: 540104988
This commit is contained in:
parent
b97d11fa76
commit
02d55dfb0a
|
@ -1077,6 +1077,7 @@ cc_test(
|
|||
linkstatic = 1,
|
||||
deps = [
|
||||
":tensor_to_image_frame_calculator",
|
||||
":tensor_to_image_frame_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
|
|
|
@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
|
|||
|
||||
private:
|
||||
float scale_factor_;
|
||||
bool scale_per_frame_min_max_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
||||
|
@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
|
|||
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
|
||||
scale_factor_ =
|
||||
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
|
||||
scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
|
||||
.scale_per_frame_min_max();
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
|
|||
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
||||
const int32_t total_size = height * width * depth;
|
||||
|
||||
if (scale_per_frame_min_max_) {
|
||||
RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
|
||||
<< "Setting scale_per_frame_min_max requires FLOAT input tensors.";
|
||||
}
|
||||
::std::unique_ptr<const ImageFrame> output;
|
||||
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
||||
// Allocate buffer with alignments.
|
||||
std::unique_ptr<uint8_t[]> buffer(
|
||||
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
||||
auto data = input_tensor.flat<float>().data();
|
||||
float min = 1e23;
|
||||
float max = -1e23;
|
||||
if (scale_per_frame_min_max_) {
|
||||
for (int i = 0; i < total_size; ++i) {
|
||||
float d = scale_factor_ * data[i];
|
||||
if (d < min) {
|
||||
min = d;
|
||||
}
|
||||
if (d > max) {
|
||||
max = d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < total_size; ++i) {
|
||||
float d = data[i];
|
||||
if (scale_per_frame_min_max_) {
|
||||
d = 255 * (d - min) / (max - min + 1e-9);
|
||||
} else {
|
||||
d = scale_factor_ * d;
|
||||
if (d < 0) d = 0;
|
||||
if (d > 255) d = 255;
|
||||
}
|
||||
buffer[i] = d;
|
||||
}
|
||||
output = ::absl::make_unique<ImageFrame>(
|
||||
|
|
|
@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
|
|||
// Multiples floating point tensor outputs by this value before converting to
|
||||
// uint8. This is useful for converting from range [0, 1] to [0, 255]
|
||||
optional float scale_factor = 1 [default = 1.0];
|
||||
|
||||
// If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
|
||||
// per frame. This overrides any explicit scale_factor.
|
||||
optional bool scale_per_frame_min_max = 2 [default = false];
|
||||
}
|
||||
|
|
|
@ -11,7 +11,9 @@
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <type_traits>
|
||||
|
||||
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
|
@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
|
|||
template <class TypeParam>
|
||||
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpRunner() {
|
||||
void SetUpRunner(bool scale_per_frame_min_max = false) {
|
||||
CalculatorGraphConfig::Node config;
|
||||
config.set_calculator("TensorToImageFrameCalculator");
|
||||
config.add_input_stream("TENSOR:input_tensor");
|
||||
config.add_output_stream("IMAGE:output_image");
|
||||
config.mutable_options()
|
||||
->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
|
||||
->set_scale_per_frame_min_max(scale_per_frame_min_max);
|
||||
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||
}
|
||||
|
||||
|
@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
|
|||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||
Converts3DTensorToImageFrame2DGrayWithScaling) {
|
||||
this->SetUpRunner(true);
|
||||
auto& runner = this->runner_;
|
||||
constexpr int kWidth = 16;
|
||||
constexpr int kHeight = 8;
|
||||
const tf::TensorShape tensor_shape{kHeight, kWidth};
|
||||
auto tensor = absl::make_unique<tf::Tensor>(
|
||||
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||
|
||||
// Writing sequence of integers as floats which we want normalized.
|
||||
tensor_vec[0] = 255;
|
||||
for (int i = 1; i < kWidth * kHeight; ++i) {
|
||||
tensor_vec[i] = 200;
|
||||
}
|
||||
|
||||
const int64_t time = 1234;
|
||||
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
if (!std::is_same<TypeParam, float>::value) {
|
||||
EXPECT_FALSE(runner->Run().ok());
|
||||
return; // Short circuit because does not apply to other types.
|
||||
} else {
|
||||
EXPECT_TRUE(runner->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner->Outputs().Tag(kImage).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||
EXPECT_EQ(kWidth, output_image.Width());
|
||||
EXPECT_EQ(kHeight, output_image.Height());
|
||||
|
||||
EXPECT_EQ(255, output_image.PixelData()[0]);
|
||||
for (int i = 1; i < kWidth * kHeight; ++i) {
|
||||
const uint8_t pixel_value = output_image.PixelData()[i];
|
||||
ASSERT_EQ(0, pixel_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user