Extracted common logics from the ImageToTensorCalculator such that it can be
reused by other calculators. PiperOrigin-RevId: 485472451
This commit is contained in:
parent
aaf98ea43c
commit
f1f123d255
|
@ -1294,13 +1294,30 @@ cc_library(
|
|||
name = "image_to_tensor_utils",
|
||||
srcs = ["image_to_tensor_utils.cc"],
|
||||
hdrs = ["image_to_tensor_utils.h"],
|
||||
copts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_to_tensor_calculator_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": ["//mediapipe/gpu:gpu_buffer"],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_test(
|
||||
|
@ -1310,6 +1327,8 @@ cc_test(
|
|||
":image_to_tensor_utils",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -54,13 +54,6 @@
|
|||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
|
||||
using GpuBuffer = AnyType;
|
||||
#else
|
||||
using GpuBuffer = mediapipe::GpuBuffer;
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
// Converts image into Tensor, possibly with cropping, resizing and
|
||||
// normalization, according to specified inputs and options.
|
||||
//
|
||||
|
@ -141,42 +134,7 @@ class ImageToTensorCalculator : public Node {
|
|||
const auto& options =
|
||||
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
|
||||
RET_CHECK(options.has_output_tensor_float_range() ||
|
||||
options.has_output_tensor_int_range() ||
|
||||
options.has_output_tensor_uint_range())
|
||||
<< "Output tensor range is required.";
|
||||
if (options.has_output_tensor_float_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_float_range().min(),
|
||||
options.output_tensor_float_range().max())
|
||||
<< "Valid output float tensor range is required.";
|
||||
}
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_uint_range().min(),
|
||||
options.output_tensor_uint_range().max())
|
||||
<< "Valid output uint tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
|
||||
<< "The minimum of the output uint tensor range must be "
|
||||
"non-negative.";
|
||||
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
|
||||
<< "The maximum of the output uint tensor range must be less than or "
|
||||
"equal to 255.";
|
||||
}
|
||||
if (options.has_output_tensor_int_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_int_range().min(),
|
||||
options.output_tensor_int_range().max())
|
||||
<< "Valid output int tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
|
||||
<< "The minimum of the output int tensor range must be greater than "
|
||||
"or equal to -128.";
|
||||
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
|
||||
<< "The maximum of the output int tensor range must be less than or "
|
||||
"equal to 127.";
|
||||
}
|
||||
RET_CHECK_GT(options.output_tensor_width(), 0)
|
||||
<< "Valid output tensor width is required.";
|
||||
RET_CHECK_GT(options.output_tensor_height(), 0)
|
||||
<< "Valid output tensor height is required.";
|
||||
|
||||
RET_CHECK_OK(ValidateOptionOutputDims(options));
|
||||
RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected())
|
||||
<< "One and only one of IMAGE and IMAGE_GPU input is expected.";
|
||||
|
||||
|
@ -198,21 +156,7 @@ class ImageToTensorCalculator : public Node {
|
|||
|
||||
absl::Status Open(CalculatorContext* cc) {
|
||||
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
output_width_ = options_.output_tensor_width();
|
||||
output_height_ = options_.output_tensor_height();
|
||||
is_float_output_ = options_.has_output_tensor_float_range();
|
||||
if (options_.has_output_tensor_uint_range()) {
|
||||
range_min_ =
|
||||
static_cast<float>(options_.output_tensor_uint_range().min());
|
||||
range_max_ =
|
||||
static_cast<float>(options_.output_tensor_uint_range().max());
|
||||
} else if (options_.has_output_tensor_int_range()) {
|
||||
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
|
||||
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
|
||||
} else {
|
||||
range_min_ = options_.output_tensor_float_range().min();
|
||||
range_max_ = options_.output_tensor_float_range().max();
|
||||
}
|
||||
params_ = GetOutputTensorParams(options_);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -242,7 +186,13 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
ASSIGN_OR_RETURN(auto image, GetInputImage(kIn(cc)));
|
||||
#else
|
||||
const bool is_input_gpu = kInGpu(cc).IsConnected();
|
||||
ASSIGN_OR_RETURN(auto image, is_input_gpu ? GetInputImage(kInGpu(cc))
|
||||
: GetInputImage(kIn(cc)));
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
|
||||
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
|
||||
|
@ -263,11 +213,13 @@ class ImageToTensorCalculator : public Node {
|
|||
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
|
||||
|
||||
Tensor::ElementType output_tensor_type =
|
||||
GetOutputTensorType(image->UsesGpu());
|
||||
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
|
||||
GetNumOutputChannels(*image)});
|
||||
GetOutputTensorType(image->UsesGpu(), params_);
|
||||
Tensor tensor(output_tensor_type,
|
||||
{1, params_.output_height, params_.output_width,
|
||||
GetNumOutputChannels(*image)});
|
||||
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
|
||||
->Convert(*image, roi, range_min_, range_max_,
|
||||
->Convert(*image, roi, params_.range_min,
|
||||
params_.range_max,
|
||||
/*tensor_buffer_offset=*/0, tensor));
|
||||
|
||||
auto result = std::make_unique<std::vector<Tensor>>();
|
||||
|
@ -278,81 +230,11 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
|
||||
private:
|
||||
bool DoesGpuInputStartAtBottom() {
|
||||
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||
}
|
||||
|
||||
BorderMode GetBorderMode() {
|
||||
switch (options_.border_mode()) {
|
||||
case mediapipe::
|
||||
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
|
||||
return BorderMode::kReplicate;
|
||||
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
|
||||
return BorderMode::kZero;
|
||||
case mediapipe::
|
||||
ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
|
||||
return BorderMode::kReplicate;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
|
||||
if (!uses_gpu) {
|
||||
if (is_float_output_) {
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
if (range_min_ < 0) {
|
||||
return Tensor::ElementType::kInt8;
|
||||
} else {
|
||||
return Tensor::ElementType::kUInt8;
|
||||
}
|
||||
}
|
||||
// Always use float32 when GPU is enabled.
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
|
||||
int GetNumOutputChannels(const Image& image) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
if (image.UsesGpu()) {
|
||||
return 4;
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
// All of the processors except for Metal expect 3 channels.
|
||||
return 3;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
CalculatorContext* cc) {
|
||||
if (kIn(cc).IsConnected()) {
|
||||
const auto& packet = kIn(cc).packet();
|
||||
return kIn(cc).Visit(
|
||||
[&packet](const mediapipe::Image&) {
|
||||
return SharedPtrWithPacket<mediapipe::Image>(packet);
|
||||
},
|
||||
[&packet](const mediapipe::ImageFrame&) {
|
||||
return std::make_shared<const mediapipe::Image>(
|
||||
std::const_pointer_cast<mediapipe::ImageFrame>(
|
||||
SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
|
||||
});
|
||||
} else { // if (kInGpu(cc).IsConnected())
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
const GpuBuffer& input = *kInGpu(cc);
|
||||
// A shallow copy is okay since the resulting 'image' object is local in
|
||||
// Process(), and thus never outlives 'input'.
|
||||
return std::make_shared<const mediapipe::Image>(input);
|
||||
#else
|
||||
return absl::UnimplementedError(
|
||||
"GPU processing is disabled in build flags");
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status InitConverterIfNecessary(CalculatorContext* cc,
|
||||
const Image& image) {
|
||||
// Lazy initialization of the GPU or CPU converter.
|
||||
if (image.UsesGpu()) {
|
||||
if (!is_float_output_) {
|
||||
if (!params_.is_float_output) {
|
||||
return absl::UnimplementedError(
|
||||
"ImageToTensorConverter for the input GPU image currently doesn't "
|
||||
"support quantization.");
|
||||
|
@ -360,18 +242,20 @@ class ImageToTensorCalculator : public Node {
|
|||
if (!gpu_converter_) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateMetalConverter(cc, GetBorderMode()));
|
||||
ASSIGN_OR_RETURN(
|
||||
gpu_converter_,
|
||||
CreateMetalConverter(cc, GetBorderMode(options_.border_mode())));
|
||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateImageToGlBufferTensorConverter(
|
||||
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||
cc, DoesGpuInputStartAtBottom(options_),
|
||||
GetBorderMode(options_.border_mode())));
|
||||
#else
|
||||
if (!gpu_converter_) {
|
||||
ASSIGN_OR_RETURN(
|
||||
gpu_converter_,
|
||||
CreateImageToGlTextureTensorConverter(
|
||||
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateImageToGlTextureTensorConverter(
|
||||
cc, DoesGpuInputStartAtBottom(options_),
|
||||
GetBorderMode(options_.border_mode())));
|
||||
}
|
||||
if (!gpu_converter_) {
|
||||
return absl::UnimplementedError(
|
||||
|
@ -383,10 +267,10 @@ class ImageToTensorCalculator : public Node {
|
|||
} else {
|
||||
if (!cpu_converter_) {
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
ASSIGN_OR_RETURN(
|
||||
cpu_converter_,
|
||||
CreateOpenCvConverter(cc, GetBorderMode(),
|
||||
GetOutputTensorType(/*uses_gpu=*/false)));
|
||||
ASSIGN_OR_RETURN(cpu_converter_,
|
||||
CreateOpenCvConverter(
|
||||
cc, GetBorderMode(options_.border_mode()),
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params_)));
|
||||
#else
|
||||
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
|
||||
"MEDIAPIPE_DISABLE_OPENCV is defined.";
|
||||
|
@ -399,11 +283,7 @@ class ImageToTensorCalculator : public Node {
|
|||
std::unique_ptr<ImageToTensorConverter> gpu_converter_;
|
||||
std::unique_ptr<ImageToTensorConverter> cpu_converter_;
|
||||
mediapipe::ImageToTensorCalculatorOptions options_;
|
||||
int output_width_ = 0;
|
||||
int output_height_ = 0;
|
||||
bool is_float_output_ = false;
|
||||
float range_min_ = 0.0f;
|
||||
float range_max_ = 1.0f;
|
||||
OutputTensorParams params_;
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);
|
||||
|
|
|
@ -27,12 +27,6 @@ struct Size {
|
|||
int height;
|
||||
};
|
||||
|
||||
// Pixel extrapolation method.
|
||||
// When converting image to tensor it may happen that tensor needs to read
|
||||
// pixels outside image boundaries. Border mode helps to specify how such pixels
|
||||
// will be calculated.
|
||||
enum class BorderMode { kZero, kReplicate };
|
||||
|
||||
// Converts image to tensor.
|
||||
class ImageToTensorConverter {
|
||||
public:
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
|
||||
|
@ -214,4 +216,68 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
|
|||
matrix[15] = 1.0f;
|
||||
}
|
||||
|
||||
BorderMode GetBorderMode(
|
||||
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode) {
|
||||
switch (mode) {
|
||||
case mediapipe::
|
||||
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
|
||||
return BorderMode::kReplicate;
|
||||
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
|
||||
return BorderMode::kZero;
|
||||
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
|
||||
return BorderMode::kReplicate;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
|
||||
const OutputTensorParams& params) {
|
||||
if (!uses_gpu) {
|
||||
if (params.is_float_output) {
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
if (params.range_min < 0) {
|
||||
return Tensor::ElementType::kInt8;
|
||||
} else {
|
||||
return Tensor::ElementType::kUInt8;
|
||||
}
|
||||
}
|
||||
// Always use float32 when GPU is enabled.
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
|
||||
int GetNumOutputChannels(const mediapipe::Image& image) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
if (image.UsesGpu()) {
|
||||
return 4;
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
// All of the processors except for Metal expect 3 channels.
|
||||
return 3;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
|
||||
image_packet) {
|
||||
return image_packet.Visit(
|
||||
[&image_packet](const mediapipe::Image&) {
|
||||
return SharedPtrWithPacket<mediapipe::Image>(image_packet);
|
||||
},
|
||||
[&image_packet](const mediapipe::ImageFrame&) {
|
||||
return std::make_shared<const mediapipe::Image>(
|
||||
std::const_pointer_cast<mediapipe::ImageFrame>(
|
||||
SharedPtrWithPacket<mediapipe::ImageFrame>(image_packet)));
|
||||
});
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet) {
|
||||
// A shallow copy is okay since the resulting 'image' object is local in
|
||||
// Process(), and thus never outlives 'input'.
|
||||
return std::make_shared<const mediapipe::Image>(image_gpu_packet.Get());
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -18,8 +18,18 @@
|
|||
#include <array>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -31,6 +41,24 @@ struct RotatedRect {
|
|||
float rotation;
|
||||
};
|
||||
|
||||
// Pixel extrapolation method.
|
||||
// When converting image to tensor it may happen that tensor needs to read
|
||||
// pixels outside image boundaries. Border mode helps to specify how such pixels
|
||||
// will be calculated.
|
||||
// TODO: Consider moving this to a separate border_mode.h file.
|
||||
enum class BorderMode { kZero, kReplicate };
|
||||
|
||||
// Struct that host commonly accessed parameters used in the
|
||||
// ImageTo[Batch]TensorCalculator.
|
||||
struct OutputTensorParams {
|
||||
int output_height;
|
||||
int output_width;
|
||||
int output_batch;
|
||||
bool is_float_output;
|
||||
float range_min;
|
||||
float range_max;
|
||||
};
|
||||
|
||||
// Generates a new ROI or converts it from normalized rect.
|
||||
RotatedRect GetRoi(int input_width, int input_height,
|
||||
absl::optional<mediapipe::NormalizedRect> norm_rect);
|
||||
|
@ -95,6 +123,103 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
|
|||
const RotatedRect& sub_rect, int rect_width, int rect_height,
|
||||
bool flip_horizontaly, std::array<float, 16>* matrix);
|
||||
|
||||
// Validates the output dimensions set in the option proto. The input option
|
||||
// proto is expected to have to following fields:
|
||||
// output_tensor_float_range, output_tensor_int_range, output_tensor_uint_range
|
||||
// output_tensor_width, output_tensor_height.
|
||||
// See ImageToTensorCalculatorOptions for the description of each field.
|
||||
template <typename T>
|
||||
absl::Status ValidateOptionOutputDims(const T& options) {
|
||||
RET_CHECK(options.has_output_tensor_float_range() ||
|
||||
options.has_output_tensor_int_range() ||
|
||||
options.has_output_tensor_uint_range())
|
||||
<< "Output tensor range is required.";
|
||||
if (options.has_output_tensor_float_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_float_range().min(),
|
||||
options.output_tensor_float_range().max())
|
||||
<< "Valid output float tensor range is required.";
|
||||
}
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_uint_range().min(),
|
||||
options.output_tensor_uint_range().max())
|
||||
<< "Valid output uint tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
|
||||
<< "The minimum of the output uint tensor range must be "
|
||||
"non-negative.";
|
||||
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
|
||||
<< "The maximum of the output uint tensor range must be less than or "
|
||||
"equal to 255.";
|
||||
}
|
||||
if (options.has_output_tensor_int_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_int_range().min(),
|
||||
options.output_tensor_int_range().max())
|
||||
<< "Valid output int tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
|
||||
<< "The minimum of the output int tensor range must be greater than "
|
||||
"or equal to -128.";
|
||||
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
|
||||
<< "The maximum of the output int tensor range must be less than or "
|
||||
"equal to 127.";
|
||||
}
|
||||
RET_CHECK_GT(options.output_tensor_width(), 0)
|
||||
<< "Valid output tensor width is required.";
|
||||
RET_CHECK_GT(options.output_tensor_height(), 0)
|
||||
<< "Valid output tensor height is required.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputTensorParams GetOutputTensorParams(const T& options) {
|
||||
OutputTensorParams params;
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
params.range_min =
|
||||
static_cast<float>(options.output_tensor_uint_range().min());
|
||||
params.range_max =
|
||||
static_cast<float>(options.output_tensor_uint_range().max());
|
||||
} else if (options.has_output_tensor_int_range()) {
|
||||
params.range_min =
|
||||
static_cast<float>(options.output_tensor_int_range().min());
|
||||
params.range_max =
|
||||
static_cast<float>(options.output_tensor_int_range().max());
|
||||
} else {
|
||||
params.range_min = options.output_tensor_float_range().min();
|
||||
params.range_max = options.output_tensor_float_range().max();
|
||||
}
|
||||
params.output_width = options.output_tensor_width();
|
||||
params.output_height = options.output_tensor_height();
|
||||
params.is_float_output = options.has_output_tensor_float_range();
|
||||
params.output_batch = 1;
|
||||
return params;
|
||||
}
|
||||
|
||||
// Returns whether the GPU input format starts at the bottom.
|
||||
template <typename T>
|
||||
bool DoesGpuInputStartAtBottom(const T& options) {
|
||||
return options.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||
}
|
||||
|
||||
// Converts the BorderMode proto into struct.
|
||||
BorderMode GetBorderMode(
|
||||
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode);
|
||||
|
||||
// Gets the output tensor type.
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
|
||||
const OutputTensorParams& params);
|
||||
|
||||
// Gets the number of output channels from the input Image format.
|
||||
int GetNumOutputChannels(const mediapipe::Image& image);
|
||||
|
||||
// Converts the packet that hosts different format (Image, ImageFrame,
|
||||
// GpuBuffer) into the mediapipe::Image format.
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
|
||||
image_packet);
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet);
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -23,6 +25,7 @@ namespace {
|
|||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
testing::Matcher<RotatedRect> EqRotatedRect(float width, float height,
|
||||
float center_x, float center_y,
|
||||
|
@ -157,5 +160,95 @@ TEST(GetValueRangeTransformation, FloatToPixel) {
|
|||
EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f));
|
||||
}
|
||||
|
||||
constexpr char kValidFloatProto[] = R"(
|
||||
output_tensor_float_range { min: 0.0 max: 1.0 }
|
||||
output_tensor_width: 100
|
||||
output_tensor_height: 200
|
||||
)";
|
||||
|
||||
constexpr char kValidIntProto[] = R"(
|
||||
output_tensor_float_range { min: 0 max: 255 }
|
||||
output_tensor_width: 100
|
||||
output_tensor_height: 200
|
||||
)";
|
||||
|
||||
TEST(ValidateOptionOutputDims, ValidProtos) {
|
||||
const auto float_options =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
|
||||
kValidFloatProto);
|
||||
MP_EXPECT_OK(ValidateOptionOutputDims(float_options));
|
||||
}
|
||||
|
||||
TEST(ValidateOptionOutputDims, EmptyProto) {
|
||||
mediapipe::ImageToTensorCalculatorOptions options;
|
||||
// No output tensor range set.
|
||||
EXPECT_THAT(ValidateOptionOutputDims(options),
|
||||
StatusIs(absl::StatusCode::kInternal,
|
||||
HasSubstr("Output tensor range is required")));
|
||||
|
||||
// Invalid output float tensor range.
|
||||
options.mutable_output_tensor_float_range()->set_min(1.0);
|
||||
options.mutable_output_tensor_float_range()->set_max(0.0);
|
||||
EXPECT_THAT(
|
||||
ValidateOptionOutputDims(options),
|
||||
StatusIs(absl::StatusCode::kInternal,
|
||||
HasSubstr("Valid output float tensor range is required")));
|
||||
|
||||
// Output width/height is not set.
|
||||
options.mutable_output_tensor_float_range()->set_min(0.0);
|
||||
options.mutable_output_tensor_float_range()->set_max(1.0);
|
||||
EXPECT_THAT(ValidateOptionOutputDims(options),
|
||||
StatusIs(absl::StatusCode::kInternal,
|
||||
HasSubstr("Valid output tensor width is required")));
|
||||
}
|
||||
|
||||
TEST(GetOutputTensorParams, SetValues) {
|
||||
// Test int range with ImageToTensorCalculatorOptions.
|
||||
const auto int_options =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
|
||||
kValidIntProto);
|
||||
const auto params2 = GetOutputTensorParams(int_options);
|
||||
EXPECT_EQ(params2.range_min, 0.0f);
|
||||
EXPECT_EQ(params2.range_max, 255.0f);
|
||||
EXPECT_EQ(params2.output_batch, 1);
|
||||
EXPECT_EQ(params2.output_width, 100);
|
||||
EXPECT_EQ(params2.output_height, 200);
|
||||
}
|
||||
|
||||
TEST(GetBorderMode, GetBorderMode) {
|
||||
// Default to REPLICATE.
|
||||
auto border_mode =
|
||||
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED;
|
||||
EXPECT_EQ(BorderMode::kReplicate, GetBorderMode(border_mode));
|
||||
|
||||
// Set to ZERO.
|
||||
border_mode =
|
||||
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO;
|
||||
EXPECT_EQ(BorderMode::kZero, GetBorderMode(border_mode));
|
||||
}
|
||||
|
||||
TEST(GetOutputTensorType, GetOutputTensorType) {
|
||||
OutputTensorParams params;
|
||||
// Return float32 when GPU is enabled.
|
||||
EXPECT_EQ(Tensor::ElementType::kFloat32,
|
||||
GetOutputTensorType(/*uses_gpu=*/true, params));
|
||||
|
||||
// Return float32 when is_float_output is set to true.
|
||||
params.is_float_output = true;
|
||||
EXPECT_EQ(Tensor::ElementType::kFloat32,
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params));
|
||||
|
||||
// Return int8 when range_min is negative.
|
||||
params.is_float_output = false;
|
||||
params.range_min = -255.0f;
|
||||
EXPECT_EQ(Tensor::ElementType::kInt8,
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params));
|
||||
|
||||
// Return 8int8 when range_min is non-negative.
|
||||
params.range_min = 0.0f;
|
||||
EXPECT_EQ(Tensor::ElementType::kUInt8,
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user