Extracted common logics from the ImageToTensorCalculator such that it can be

reused by other calculators.

PiperOrigin-RevId: 485472451
This commit is contained in:
MediaPipe Team 2022-11-01 18:41:49 -07:00 committed by Copybara-Service
parent aaf98ea43c
commit f1f123d255
6 changed files with 335 additions and 158 deletions

View File

@ -1294,13 +1294,30 @@ cc_library(
name = "image_to_tensor_utils",
srcs = ["image_to_tensor_utils.cc"],
hdrs = ["image_to_tensor_utils.h"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":image_to_tensor_calculator_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/types:optional",
],
"//mediapipe/gpu:gpu_origin_cc_proto",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": ["//mediapipe/gpu:gpu_buffer"],
}),
)
cc_test(
@ -1310,6 +1327,8 @@ cc_test(
":image_to_tensor_utils",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)

View File

@ -54,13 +54,6 @@
namespace mediapipe {
namespace api2 {
#if MEDIAPIPE_DISABLE_GPU
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
using GpuBuffer = AnyType;
#else
using GpuBuffer = mediapipe::GpuBuffer;
#endif // MEDIAPIPE_DISABLE_GPU
// Converts image into Tensor, possibly with cropping, resizing and
// normalization, according to specified inputs and options.
//
@ -141,42 +134,7 @@ class ImageToTensorCalculator : public Node {
const auto& options =
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
RET_CHECK_OK(ValidateOptionOutputDims(options));
RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected())
<< "One and only one of IMAGE and IMAGE_GPU input is expected.";
@ -198,21 +156,7 @@ class ImageToTensorCalculator : public Node {
absl::Status Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
output_width_ = options_.output_tensor_width();
output_height_ = options_.output_tensor_height();
is_float_output_ = options_.has_output_tensor_float_range();
if (options_.has_output_tensor_uint_range()) {
range_min_ =
static_cast<float>(options_.output_tensor_uint_range().min());
range_max_ =
static_cast<float>(options_.output_tensor_uint_range().max());
} else if (options_.has_output_tensor_int_range()) {
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
} else {
range_min_ = options_.output_tensor_float_range().min();
range_max_ = options_.output_tensor_float_range().max();
}
params_ = GetOutputTensorParams(options_);
return absl::OkStatus();
}
@ -242,7 +186,13 @@ class ImageToTensorCalculator : public Node {
}
}
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
#if MEDIAPIPE_DISABLE_GPU
ASSIGN_OR_RETURN(auto image, GetInputImage(kIn(cc)));
#else
const bool is_input_gpu = kInGpu(cc).IsConnected();
ASSIGN_OR_RETURN(auto image, is_input_gpu ? GetInputImage(kInGpu(cc))
: GetInputImage(kIn(cc)));
#endif // MEDIAPIPE_DISABLE_GPU
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
@ -263,11 +213,13 @@ class ImageToTensorCalculator : public Node {
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
Tensor::ElementType output_tensor_type =
GetOutputTensorType(image->UsesGpu());
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
GetOutputTensorType(image->UsesGpu(), params_);
Tensor tensor(output_tensor_type,
{1, params_.output_height, params_.output_width,
GetNumOutputChannels(*image)});
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, range_min_, range_max_,
->Convert(*image, roi, params_.range_min,
params_.range_max,
/*tensor_buffer_offset=*/0, tensor));
auto result = std::make_unique<std::vector<Tensor>>();
@ -278,81 +230,11 @@ class ImageToTensorCalculator : public Node {
}
private:
bool DoesGpuInputStartAtBottom() {
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
BorderMode GetBorderMode() {
switch (options_.border_mode()) {
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
return BorderMode::kReplicate;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
return BorderMode::kZero;
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
return BorderMode::kReplicate;
}
}
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
if (!uses_gpu) {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
CalculatorContext* cc) {
if (kIn(cc).IsConnected()) {
const auto& packet = kIn(cc).packet();
return kIn(cc).Visit(
[&packet](const mediapipe::Image&) {
return SharedPtrWithPacket<mediapipe::Image>(packet);
},
[&packet](const mediapipe::ImageFrame&) {
return std::make_shared<const mediapipe::Image>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
});
} else { // if (kInGpu(cc).IsConnected())
#if !MEDIAPIPE_DISABLE_GPU
const GpuBuffer& input = *kInGpu(cc);
// A shallow copy is okay since the resulting 'image' object is local in
// Process(), and thus never outlives 'input'.
return std::make_shared<const mediapipe::Image>(input);
#else
return absl::UnimplementedError(
"GPU processing is disabled in build flags");
#endif // !MEDIAPIPE_DISABLE_GPU
}
}
absl::Status InitConverterIfNecessary(CalculatorContext* cc,
const Image& image) {
// Lazy initialization of the GPU or CPU converter.
if (image.UsesGpu()) {
if (!is_float_output_) {
if (!params_.is_float_output) {
return absl::UnimplementedError(
"ImageToTensorConverter for the input GPU image currently doesn't "
"support quantization.");
@ -360,18 +242,20 @@ class ImageToTensorCalculator : public Node {
if (!gpu_converter_) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
ASSIGN_OR_RETURN(gpu_converter_,
CreateMetalConverter(cc, GetBorderMode()));
ASSIGN_OR_RETURN(
gpu_converter_,
CreateMetalConverter(cc, GetBorderMode(options_.border_mode())));
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlBufferTensorConverter(
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
cc, DoesGpuInputStartAtBottom(options_),
GetBorderMode(options_.border_mode())));
#else
if (!gpu_converter_) {
ASSIGN_OR_RETURN(
gpu_converter_,
ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlTextureTensorConverter(
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
cc, DoesGpuInputStartAtBottom(options_),
GetBorderMode(options_.border_mode())));
}
if (!gpu_converter_) {
return absl::UnimplementedError(
@ -383,10 +267,10 @@ class ImageToTensorCalculator : public Node {
} else {
if (!cpu_converter_) {
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(
cpu_converter_,
CreateOpenCvConverter(cc, GetBorderMode(),
GetOutputTensorType(/*uses_gpu=*/false)));
ASSIGN_OR_RETURN(cpu_converter_,
CreateOpenCvConverter(
cc, GetBorderMode(options_.border_mode()),
GetOutputTensorType(/*uses_gpu=*/false, params_)));
#else
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined.";
@ -399,11 +283,7 @@ class ImageToTensorCalculator : public Node {
std::unique_ptr<ImageToTensorConverter> gpu_converter_;
std::unique_ptr<ImageToTensorConverter> cpu_converter_;
mediapipe::ImageToTensorCalculatorOptions options_;
int output_width_ = 0;
int output_height_ = 0;
bool is_float_output_ = false;
float range_min_ = 0.0f;
float range_max_ = 1.0f;
OutputTensorParams params_;
};
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);

View File

@ -27,12 +27,6 @@ struct Size {
int height;
};
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
enum class BorderMode { kZero, kReplicate };
// Converts image to tensor.
class ImageToTensorConverter {
public:

View File

@ -16,7 +16,9 @@
#include <array>
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
@ -214,4 +216,68 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
matrix[15] = 1.0f;
}
BorderMode GetBorderMode(
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode) {
switch (mode) {
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
return BorderMode::kReplicate;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
return BorderMode::kZero;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
return BorderMode::kReplicate;
}
}
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
const OutputTensorParams& params) {
if (!uses_gpu) {
if (params.is_float_output) {
return Tensor::ElementType::kFloat32;
}
if (params.range_min < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const mediapipe::Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
image_packet) {
return image_packet.Visit(
[&image_packet](const mediapipe::Image&) {
return SharedPtrWithPacket<mediapipe::Image>(image_packet);
},
[&image_packet](const mediapipe::ImageFrame&) {
return std::make_shared<const mediapipe::Image>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(image_packet)));
});
}
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet) {
// A shallow copy is okay since the resulting 'image' object is local in
// Process(), and thus never outlives 'input'.
return std::make_shared<const mediapipe::Image>(image_gpu_packet.Get());
}
#endif // !MEDIAPIPE_DISABLE_GPU
} // namespace mediapipe

View File

@ -18,8 +18,18 @@
#include <array>
#include "absl/types/optional.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h"
#endif // !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_origin.pb.h"
namespace mediapipe {
@ -31,6 +41,24 @@ struct RotatedRect {
float rotation;
};
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
// TODO: Consider moving this to a separate border_mode.h file.
enum class BorderMode { kZero, kReplicate };
// Struct that host commonly accessed parameters used in the
// ImageTo[Batch]TensorCalculator.
struct OutputTensorParams {
int output_height;
int output_width;
int output_batch;
bool is_float_output;
float range_min;
float range_max;
};
// Generates a new ROI or converts it from normalized rect.
RotatedRect GetRoi(int input_width, int input_height,
absl::optional<mediapipe::NormalizedRect> norm_rect);
@ -95,6 +123,103 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
const RotatedRect& sub_rect, int rect_width, int rect_height,
bool flip_horizontaly, std::array<float, 16>* matrix);
// Validates the output dimensions set in the option proto. The input option
// proto is expected to have to following fields:
// output_tensor_float_range, output_tensor_int_range, output_tensor_uint_range
// output_tensor_width, output_tensor_height.
// See ImageToTensorCalculatorOptions for the description of each field.
template <typename T>
absl::Status ValidateOptionOutputDims(const T& options) {
RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
return absl::OkStatus();
}
template <typename T>
OutputTensorParams GetOutputTensorParams(const T& options) {
OutputTensorParams params;
if (options.has_output_tensor_uint_range()) {
params.range_min =
static_cast<float>(options.output_tensor_uint_range().min());
params.range_max =
static_cast<float>(options.output_tensor_uint_range().max());
} else if (options.has_output_tensor_int_range()) {
params.range_min =
static_cast<float>(options.output_tensor_int_range().min());
params.range_max =
static_cast<float>(options.output_tensor_int_range().max());
} else {
params.range_min = options.output_tensor_float_range().min();
params.range_max = options.output_tensor_float_range().max();
}
params.output_width = options.output_tensor_width();
params.output_height = options.output_tensor_height();
params.is_float_output = options.has_output_tensor_float_range();
params.output_batch = 1;
return params;
}
// Returns whether the GPU input format starts at the bottom.
template <typename T>
bool DoesGpuInputStartAtBottom(const T& options) {
return options.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
// Converts the BorderMode proto into struct.
BorderMode GetBorderMode(
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode);
// Gets the output tensor type.
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
const OutputTensorParams& params);
// Gets the number of output channels from the input Image format.
int GetNumOutputChannels(const mediapipe::Image& image);
// Converts the packet that hosts different format (Image, ImageFrame,
// GpuBuffer) into the mediapipe::Image format.
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
image_packet);
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet);
#endif // !MEDIAPIPE_DISABLE_GPU
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_

View File

@ -16,6 +16,8 @@
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
@ -23,6 +25,7 @@ namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
testing::Matcher<RotatedRect> EqRotatedRect(float width, float height,
float center_x, float center_y,
@ -157,5 +160,95 @@ TEST(GetValueRangeTransformation, FloatToPixel) {
EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f));
}
constexpr char kValidFloatProto[] = R"(
output_tensor_float_range { min: 0.0 max: 1.0 }
output_tensor_width: 100
output_tensor_height: 200
)";
constexpr char kValidIntProto[] = R"(
output_tensor_float_range { min: 0 max: 255 }
output_tensor_width: 100
output_tensor_height: 200
)";
TEST(ValidateOptionOutputDims, ValidProtos) {
const auto float_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidFloatProto);
MP_EXPECT_OK(ValidateOptionOutputDims(float_options));
}
TEST(ValidateOptionOutputDims, EmptyProto) {
mediapipe::ImageToTensorCalculatorOptions options;
// No output tensor range set.
EXPECT_THAT(ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Output tensor range is required")));
// Invalid output float tensor range.
options.mutable_output_tensor_float_range()->set_min(1.0);
options.mutable_output_tensor_float_range()->set_max(0.0);
EXPECT_THAT(
ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Valid output float tensor range is required")));
// Output width/height is not set.
options.mutable_output_tensor_float_range()->set_min(0.0);
options.mutable_output_tensor_float_range()->set_max(1.0);
EXPECT_THAT(ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Valid output tensor width is required")));
}
TEST(GetOutputTensorParams, SetValues) {
// Test int range with ImageToTensorCalculatorOptions.
const auto int_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidIntProto);
const auto params2 = GetOutputTensorParams(int_options);
EXPECT_EQ(params2.range_min, 0.0f);
EXPECT_EQ(params2.range_max, 255.0f);
EXPECT_EQ(params2.output_batch, 1);
EXPECT_EQ(params2.output_width, 100);
EXPECT_EQ(params2.output_height, 200);
}
TEST(GetBorderMode, GetBorderMode) {
// Default to REPLICATE.
auto border_mode =
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED;
EXPECT_EQ(BorderMode::kReplicate, GetBorderMode(border_mode));
// Set to ZERO.
border_mode =
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO;
EXPECT_EQ(BorderMode::kZero, GetBorderMode(border_mode));
}
TEST(GetOutputTensorType, GetOutputTensorType) {
OutputTensorParams params;
// Return float32 when GPU is enabled.
EXPECT_EQ(Tensor::ElementType::kFloat32,
GetOutputTensorType(/*uses_gpu=*/true, params));
// Return float32 when is_float_output is set to true.
params.is_float_output = true;
EXPECT_EQ(Tensor::ElementType::kFloat32,
GetOutputTensorType(/*uses_gpu=*/false, params));
// Return int8 when range_min is negative.
params.is_float_output = false;
params.range_min = -255.0f;
EXPECT_EQ(Tensor::ElementType::kInt8,
GetOutputTensorType(/*uses_gpu=*/false, params));
// Return 8int8 when range_min is non-negative.
params.range_min = 0.0f;
EXPECT_EQ(Tensor::ElementType::kUInt8,
GetOutputTensorType(/*uses_gpu=*/false, params));
}
} // namespace
} // namespace mediapipe