No public description
PiperOrigin-RevId: 583936442
This commit is contained in:
		
							parent
							
								
									42d42a5ea1
								
							
						
					
					
						commit
						bd4be30b02
					
				| 
						 | 
					@ -657,6 +657,7 @@ cc_library(
 | 
				
			||||||
    }),
 | 
					    }),
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":tensor_converter_calculator_cc_proto",
 | 
					        ":tensor_converter_calculator_cc_proto",
 | 
				
			||||||
 | 
					        ":tensor_converter_cpu",
 | 
				
			||||||
        "//mediapipe/framework:calculator_framework",
 | 
					        "//mediapipe/framework:calculator_framework",
 | 
				
			||||||
        "//mediapipe/framework:port",
 | 
					        "//mediapipe/framework:port",
 | 
				
			||||||
        "//mediapipe/framework/formats:image_frame",
 | 
					        "//mediapipe/framework/formats:image_frame",
 | 
				
			||||||
| 
						 | 
					@ -665,6 +666,7 @@ cc_library(
 | 
				
			||||||
        "//mediapipe/framework/port:ret_check",
 | 
					        "//mediapipe/framework/port:ret_check",
 | 
				
			||||||
        "//mediapipe/framework/port:status",
 | 
					        "//mediapipe/framework/port:status",
 | 
				
			||||||
        "//mediapipe/framework/port:statusor",
 | 
					        "//mediapipe/framework/port:statusor",
 | 
				
			||||||
 | 
					        "//mediapipe/gpu:gpu_buffer",
 | 
				
			||||||
        "//mediapipe/gpu:gpu_buffer_format",
 | 
					        "//mediapipe/gpu:gpu_buffer_format",
 | 
				
			||||||
        "//mediapipe/gpu:gpu_origin_cc_proto",
 | 
					        "//mediapipe/gpu:gpu_origin_cc_proto",
 | 
				
			||||||
        "//mediapipe/util:resource_util",
 | 
					        "//mediapipe/util:resource_util",
 | 
				
			||||||
| 
						 | 
					@ -674,10 +676,17 @@ cc_library(
 | 
				
			||||||
        "@com_google_absl//absl/log:check",
 | 
					        "@com_google_absl//absl/log:check",
 | 
				
			||||||
        "@com_google_absl//absl/status",
 | 
					        "@com_google_absl//absl/status",
 | 
				
			||||||
        "@com_google_absl//absl/status:statusor",
 | 
					        "@com_google_absl//absl/status:statusor",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/strings",
 | 
				
			||||||
        "@com_google_absl//absl/strings:str_format",
 | 
					        "@com_google_absl//absl/strings:str_format",
 | 
				
			||||||
    ] + select({
 | 
					    ] + select({
 | 
				
			||||||
        "//mediapipe/gpu:disable_gpu": [],
 | 
					        "//mediapipe/gpu:disable_gpu": [],
 | 
				
			||||||
        "//conditions:default": ["tensor_converter_calculator_gpu_deps"],
 | 
					        "//conditions:default": [
 | 
				
			||||||
 | 
					            "tensor_converter_calculator_gpu_deps",
 | 
				
			||||||
 | 
					            "//mediapipe/gpu:gl_base",
 | 
				
			||||||
 | 
					            "//mediapipe/gpu:gl_calculator_helper",
 | 
				
			||||||
 | 
					            "//mediapipe/gpu:gl_simple_shaders",
 | 
				
			||||||
 | 
					            "//mediapipe/gpu:shader_util",
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
    }) + select({
 | 
					    }) + select({
 | 
				
			||||||
        "//mediapipe:apple": [
 | 
					        "//mediapipe:apple": [
 | 
				
			||||||
            "//third_party/apple_frameworks:MetalKit",
 | 
					            "//third_party/apple_frameworks:MetalKit",
 | 
				
			||||||
| 
						 | 
					@ -687,6 +696,35 @@ cc_library(
 | 
				
			||||||
    alwayslink = 1,
 | 
					    alwayslink = 1,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "tensor_converter_cpu",
 | 
				
			||||||
 | 
					    srcs = ["tensor_converter_cpu.cc"],
 | 
				
			||||||
 | 
					    hdrs = ["tensor_converter_cpu.h"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:image_frame",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:matrix",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:tensor",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:ret_check",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:status",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/status",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/status:statusor",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_test(
 | 
				
			||||||
 | 
					    name = "tensor_converter_cpu_test",
 | 
				
			||||||
 | 
					    srcs = ["tensor_converter_cpu_test.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":tensor_converter_cpu",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:matrix",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:tensor",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest_main",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:status_matchers",
 | 
				
			||||||
 | 
					        "//mediapipe/util:image_test_utils",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "tensor_converter_calculator_gpu_deps",
 | 
					    name = "tensor_converter_calculator_gpu_deps",
 | 
				
			||||||
    visibility = ["//visibility:private"],
 | 
					    visibility = ["//visibility:private"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,6 +14,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <cstdint>
 | 
					#include <cstdint>
 | 
				
			||||||
#include <string>
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
#include <vector>
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "absl/log/absl_check.h"
 | 
					#include "absl/log/absl_check.h"
 | 
				
			||||||
| 
						 | 
					@ -21,17 +22,22 @@
 | 
				
			||||||
#include "absl/status/status.h"
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
#include "absl/status/statusor.h"
 | 
					#include "absl/status/statusor.h"
 | 
				
			||||||
#include "absl/strings/str_format.h"
 | 
					#include "absl/strings/str_format.h"
 | 
				
			||||||
 | 
					#include "absl/strings/substitute.h"
 | 
				
			||||||
 | 
					#include "absl/types/optional.h"
 | 
				
			||||||
#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h"
 | 
					#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/calculators/tensor/tensor_converter_cpu.h"
 | 
				
			||||||
#include "mediapipe/framework/calculator_framework.h"
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
#include "mediapipe/framework/formats/image_frame.h"
 | 
					#include "mediapipe/framework/formats/image_frame.h"
 | 
				
			||||||
#include "mediapipe/framework/formats/matrix.h"
 | 
					#include "mediapipe/framework/formats/matrix.h"
 | 
				
			||||||
#include "mediapipe/framework/formats/tensor.h"
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
#include "mediapipe/framework/port.h"
 | 
					#include "mediapipe/framework/port.h"
 | 
				
			||||||
#include "mediapipe/framework/port/ret_check.h"
 | 
					#include "mediapipe/framework/port/ret_check.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/status_macros.h"
 | 
				
			||||||
#include "mediapipe/gpu/gpu_buffer_format.h"
 | 
					#include "mediapipe/gpu/gpu_buffer_format.h"
 | 
				
			||||||
#include "mediapipe/gpu/gpu_origin.pb.h"
 | 
					#include "mediapipe/gpu/gpu_origin.pb.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#if !MEDIAPIPE_DISABLE_GPU
 | 
					#if !MEDIAPIPE_DISABLE_GPU
 | 
				
			||||||
 | 
					#include "mediapipe/gpu/gl_base.h"
 | 
				
			||||||
#include "mediapipe/gpu/gpu_buffer.h"
 | 
					#include "mediapipe/gpu/gpu_buffer.h"
 | 
				
			||||||
#if MEDIAPIPE_METAL_ENABLED
 | 
					#if MEDIAPIPE_METAL_ENABLED
 | 
				
			||||||
#import <CoreVideo/CoreVideo.h>
 | 
					#import <CoreVideo/CoreVideo.h>
 | 
				
			||||||
| 
						 | 
					@ -94,16 +100,13 @@ absl::StatusOr<bool> ShouldFlipVertically(
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
 | 
					 | 
				
			||||||
    RowMajorMatrixXf;
 | 
					 | 
				
			||||||
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
 | 
					 | 
				
			||||||
    ColMajorMatrixXf;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
constexpr char kImageFrameTag[] = "IMAGE";
 | 
					constexpr char kImageFrameTag[] = "IMAGE";
 | 
				
			||||||
constexpr char kGpuBufferTag[] = "IMAGE_GPU";
 | 
					constexpr char kGpuBufferTag[] = "IMAGE_GPU";
 | 
				
			||||||
constexpr char kTensorsTag[] = "TENSORS";
 | 
					constexpr char kTensorsTag[] = "TENSORS";
 | 
				
			||||||
constexpr char kMatrixTag[] = "MATRIX";
 | 
					constexpr char kMatrixTag[] = "MATRIX";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					constexpr std::pair<float, float> kDefaultOutputRange = {0.0f, 1.0f};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
| 
						 | 
					@ -156,10 +159,6 @@ class TensorConverterCalculator : public CalculatorBase {
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  absl::Status InitGpu(CalculatorContext* cc);
 | 
					  absl::Status InitGpu(CalculatorContext* cc);
 | 
				
			||||||
  absl::Status LoadOptions(CalculatorContext* cc, bool use_gpu);
 | 
					  absl::Status LoadOptions(CalculatorContext* cc, bool use_gpu);
 | 
				
			||||||
  template <class T>
 | 
					 | 
				
			||||||
  absl::Status NormalizeImage(const ImageFrame& image_frame,
 | 
					 | 
				
			||||||
                              bool flip_vertically, float* tensor_ptr);
 | 
					 | 
				
			||||||
  absl::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr);
 | 
					 | 
				
			||||||
  absl::Status ProcessCPU(CalculatorContext* cc);
 | 
					  absl::Status ProcessCPU(CalculatorContext* cc);
 | 
				
			||||||
  absl::Status ProcessGPU(CalculatorContext* cc);
 | 
					  absl::Status ProcessGPU(CalculatorContext* cc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -279,46 +278,21 @@ absl::Status TensorConverterCalculator::ProcessCPU(CalculatorContext* cc) {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    const auto& image_frame =
 | 
					    const auto& image_frame =
 | 
				
			||||||
        cc->Inputs().Tag(kImageFrameTag).Get<ImageFrame>();
 | 
					        cc->Inputs().Tag(kImageFrameTag).Get<ImageFrame>();
 | 
				
			||||||
    const int height = image_frame.Height();
 | 
					    MP_ASSIGN_OR_RETURN(Tensor output,
 | 
				
			||||||
    const int width = image_frame.Width();
 | 
					                        ConvertImageFrameToTensorOnCpu(
 | 
				
			||||||
    const int channels = image_frame.NumberOfChannels();
 | 
					                            image_frame,
 | 
				
			||||||
    const int channels_preserved = std::min(channels, max_num_channels_);
 | 
					                            output_range_.has_value() ? output_range_.value()
 | 
				
			||||||
    const mediapipe::ImageFormat::Format format = image_frame.Format();
 | 
					                                                      : kDefaultOutputRange,
 | 
				
			||||||
 | 
					                            flip_vertically_, max_num_channels_));
 | 
				
			||||||
    if (!(format == mediapipe::ImageFormat::SRGBA ||
 | 
					    output_tensors->emplace_back(std::move(output));
 | 
				
			||||||
          format == mediapipe::ImageFormat::SRGB ||
 | 
					 | 
				
			||||||
          format == mediapipe::ImageFormat::GRAY8 ||
 | 
					 | 
				
			||||||
          format == mediapipe::ImageFormat::VEC32F1))
 | 
					 | 
				
			||||||
      RET_CHECK_FAIL() << "Unsupported CPU input format.";
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    output_tensors->emplace_back(
 | 
					 | 
				
			||||||
        Tensor::ElementType::kFloat32,
 | 
					 | 
				
			||||||
        Tensor::Shape{1, height, width, channels_preserved});
 | 
					 | 
				
			||||||
    auto cpu_view = output_tensors->back().GetCpuWriteView();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Copy image data into tensor.
 | 
					 | 
				
			||||||
    if (image_frame.ByteDepth() == 1) {
 | 
					 | 
				
			||||||
      MP_RETURN_IF_ERROR(NormalizeImage<uint8_t>(image_frame, flip_vertically_,
 | 
					 | 
				
			||||||
                                                 cpu_view.buffer<float>()));
 | 
					 | 
				
			||||||
    } else if (image_frame.ByteDepth() == 4) {
 | 
					 | 
				
			||||||
      MP_RETURN_IF_ERROR(NormalizeImage<float>(image_frame, flip_vertically_,
 | 
					 | 
				
			||||||
                                               cpu_view.buffer<float>()));
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      return absl::InternalError(
 | 
					 | 
				
			||||||
          "Only byte-based (8 bit) and float (32 bit) images supported.");
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  } else if (cc->Inputs().HasTag(kMatrixTag)) {
 | 
					  } else if (cc->Inputs().HasTag(kMatrixTag)) {
 | 
				
			||||||
    if (cc->Inputs().Tag(kMatrixTag).IsEmpty()) {
 | 
					    if (cc->Inputs().Tag(kMatrixTag).IsEmpty()) {
 | 
				
			||||||
      return absl::OkStatus();
 | 
					      return absl::OkStatus();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    const auto& matrix = cc->Inputs().Tag(kMatrixTag).Get<Matrix>();
 | 
					    const auto& matrix = cc->Inputs().Tag(kMatrixTag).Get<Matrix>();
 | 
				
			||||||
    const int height = matrix.rows();
 | 
					    MP_ASSIGN_OR_RETURN(Tensor output,
 | 
				
			||||||
    const int width = matrix.cols();
 | 
					                        ConvertMatrixToTensorOnCpu(matrix, row_major_matrix_));
 | 
				
			||||||
    const int channels = 1;
 | 
					    output_tensors->emplace_back(std::move(output));
 | 
				
			||||||
    output_tensors->emplace_back(Tensor::ElementType::kFloat32,
 | 
					 | 
				
			||||||
                                 Tensor::Shape{1, height, width, channels});
 | 
					 | 
				
			||||||
    MP_RETURN_IF_ERROR(CopyMatrixToTensor(
 | 
					 | 
				
			||||||
        matrix, output_tensors->back().GetCpuWriteView().buffer<float>()));
 | 
					 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    return absl::OkStatus();
 | 
					    return absl::OkStatus();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -669,67 +643,4 @@ absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc,
 | 
				
			||||||
  return absl::OkStatus();
 | 
					  return absl::OkStatus();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <class T>
 | 
					 | 
				
			||||||
absl::Status TensorConverterCalculator::NormalizeImage(
 | 
					 | 
				
			||||||
    const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) {
 | 
					 | 
				
			||||||
  const int height = image_frame.Height();
 | 
					 | 
				
			||||||
  const int width = image_frame.Width();
 | 
					 | 
				
			||||||
  const int channels = image_frame.NumberOfChannels();
 | 
					 | 
				
			||||||
  const int channels_preserved = std::min(channels, max_num_channels_);
 | 
					 | 
				
			||||||
  const int channels_ignored = channels - channels_preserved;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (output_range_.has_value()) {
 | 
					 | 
				
			||||||
    // If the output float range is set and we are not using custom
 | 
					 | 
				
			||||||
    // normalization, normalize the pixel values from [0, 255] to the specified
 | 
					 | 
				
			||||||
    // output range.
 | 
					 | 
				
			||||||
    RET_CHECK_NE(output_range_->first, output_range_->second);
 | 
					 | 
				
			||||||
    const float scale = (output_range_->second - output_range_->first) / 255.0f;
 | 
					 | 
				
			||||||
    const float bias = output_range_->first;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int i = 0; i < height; ++i) {
 | 
					 | 
				
			||||||
      const T* image_ptr = reinterpret_cast<const T*>(
 | 
					 | 
				
			||||||
          image_frame.PixelData() +
 | 
					 | 
				
			||||||
          (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep());
 | 
					 | 
				
			||||||
      for (int j = 0; j < width; ++j) {
 | 
					 | 
				
			||||||
        for (int c = 0; c < channels_preserved; ++c) {
 | 
					 | 
				
			||||||
          *tensor_ptr++ = *image_ptr++ * scale + bias;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        image_ptr += channels_ignored;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    // [0,1], scale only (bias == 0)
 | 
					 | 
				
			||||||
    // Verified that there are no precision issues with 1.0f / 255.0f expression
 | 
					 | 
				
			||||||
    const float scale = 1.0f / 255.0f;
 | 
					 | 
				
			||||||
    for (int i = 0; i < height; ++i) {
 | 
					 | 
				
			||||||
      const T* image_ptr = reinterpret_cast<const T*>(
 | 
					 | 
				
			||||||
          image_frame.PixelData() +
 | 
					 | 
				
			||||||
          (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep());
 | 
					 | 
				
			||||||
      for (int j = 0; j < width; ++j) {
 | 
					 | 
				
			||||||
        for (int c = 0; c < channels_preserved; ++c) {
 | 
					 | 
				
			||||||
          *tensor_ptr++ = *image_ptr++ * scale;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        image_ptr += channels_ignored;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return absl::OkStatus();
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
absl::Status TensorConverterCalculator::CopyMatrixToTensor(const Matrix& matrix,
 | 
					 | 
				
			||||||
                                                           float* tensor_ptr) {
 | 
					 | 
				
			||||||
  if (row_major_matrix_) {
 | 
					 | 
				
			||||||
    auto matrix_map =
 | 
					 | 
				
			||||||
        Eigen::Map<RowMajorMatrixXf>(tensor_ptr, matrix.rows(), matrix.cols());
 | 
					 | 
				
			||||||
    matrix_map = matrix;
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    auto matrix_map =
 | 
					 | 
				
			||||||
        Eigen::Map<ColMajorMatrixXf>(tensor_ptr, matrix.rows(), matrix.cols());
 | 
					 | 
				
			||||||
    matrix_map = matrix;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return absl::OkStatus();
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -321,6 +321,61 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(TensorConverterCalculatorTest,
 | 
				
			||||||
 | 
					       ShouldConvertImageWithDefaultOutputRange) {
 | 
				
			||||||
 | 
					  CalculatorGraph graph;
 | 
				
			||||||
 | 
					  CalculatorGraphConfig graph_config =
 | 
				
			||||||
 | 
					      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
 | 
				
			||||||
 | 
					          R"pb(
 | 
				
			||||||
 | 
					            input_stream: "input_image"
 | 
				
			||||||
 | 
					            node {
 | 
				
			||||||
 | 
					              calculator: "TensorConverterCalculator"
 | 
				
			||||||
 | 
					              input_stream: "IMAGE:input_image"
 | 
				
			||||||
 | 
					              output_stream: "TENSORS:tensor"
 | 
				
			||||||
 | 
					              options {
 | 
				
			||||||
 | 
					                [mediapipe.TensorConverterCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					                  zero_center: false
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          )pb");
 | 
				
			||||||
 | 
					  std::vector<Packet> output_packets;
 | 
				
			||||||
 | 
					  tool::AddVectorSink("tensor", &graph_config, &output_packets);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Run the graph.
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.Initialize(graph_config));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.StartRun({}));
 | 
				
			||||||
 | 
					  auto input_image = std::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 1);
 | 
				
			||||||
 | 
					  cv::Mat mat = mediapipe::formats::MatView(input_image.get());
 | 
				
			||||||
 | 
					  mat.at<uint8_t>(0, 0) = 200;
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
				
			||||||
 | 
					      "input_image", Adopt(input_image.release()).At(Timestamp(0))));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Wait until the calculator finishes processing.
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
				
			||||||
 | 
					  ASSERT_EQ(output_packets.size(), 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Get and process results.
 | 
				
			||||||
 | 
					  const std::vector<Tensor>& tensor_vec =
 | 
				
			||||||
 | 
					      output_packets[0].Get<std::vector<Tensor>>();
 | 
				
			||||||
 | 
					  ASSERT_EQ(tensor_vec.size(), 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const Tensor* tensor = &tensor_vec[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Calculate the expected normalized value:
 | 
				
			||||||
 | 
					  float expected_value = 200.0 / 255.0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
 | 
				
			||||||
 | 
					  auto view = tensor->GetCpuReadView();
 | 
				
			||||||
 | 
					  float actual_value = *view.buffer<float>();
 | 
				
			||||||
 | 
					  EXPECT_FLOAT_EQ(actual_value, expected_value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Fully close graph at end, otherwise calculator+tensors are destroyed
 | 
				
			||||||
 | 
					  // after calling WaitUntilDone().
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.CloseInputStream("input_image"));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_F(TensorConverterCalculatorTest, FlipVertically) {
 | 
					TEST_F(TensorConverterCalculatorTest, FlipVertically) {
 | 
				
			||||||
  CalculatorGraph graph;
 | 
					  CalculatorGraph graph;
 | 
				
			||||||
  CalculatorGraphConfig graph_config =
 | 
					  CalculatorGraphConfig graph_config =
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										145
									
								
								mediapipe/calculators/tensor/tensor_converter_cpu.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								mediapipe/calculators/tensor/tensor_converter_cpu.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,145 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/calculators/tensor/tensor_converter_cpu.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <algorithm>
 | 
				
			||||||
 | 
					#include <cstdint>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
 | 
					#include "absl/status/statusor.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/image_frame.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/matrix.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/ret_check.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/status_macros.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
 | 
				
			||||||
 | 
					    RowMajorMatrixXf;
 | 
				
			||||||
 | 
					typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
 | 
				
			||||||
 | 
					    ColMajorMatrixXf;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <class T>
 | 
				
			||||||
 | 
					absl::Status NormalizeImage(const ImageFrame& image_frame, bool flip_vertically,
 | 
				
			||||||
 | 
					                            const std::pair<float, float>& output_range,
 | 
				
			||||||
 | 
					                            int max_num_channels, float* tensor_ptr) {
 | 
				
			||||||
 | 
					  const int height = image_frame.Height();
 | 
				
			||||||
 | 
					  const int width = image_frame.Width();
 | 
				
			||||||
 | 
					  const int channels = image_frame.NumberOfChannels();
 | 
				
			||||||
 | 
					  const int channels_preserved = std::min(channels, max_num_channels);
 | 
				
			||||||
 | 
					  const int channels_ignored = channels - channels_preserved;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  RET_CHECK_NE(output_range.first, output_range.second);
 | 
				
			||||||
 | 
					  const float scale = (output_range.second - output_range.first) / 255.0f;
 | 
				
			||||||
 | 
					  const float bias = output_range.first;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (int i = 0; i < height; ++i) {
 | 
				
			||||||
 | 
					    const T* image_ptr = reinterpret_cast<const T*>(
 | 
				
			||||||
 | 
					        image_frame.PixelData() +
 | 
				
			||||||
 | 
					        (flip_vertically ? height - 1 - i : i) * image_frame.WidthStep());
 | 
				
			||||||
 | 
					    for (int j = 0; j < width; ++j) {
 | 
				
			||||||
 | 
					      for (int c = 0; c < channels_preserved; ++c) {
 | 
				
			||||||
 | 
					        *tensor_ptr++ = *image_ptr++ * scale + bias;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      image_ptr += channels_ignored;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return absl::OkStatus();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::Status NormalizeUInt8Image(const ImageFrame& image_frame,
 | 
				
			||||||
 | 
					                                 bool flip_vertically,
 | 
				
			||||||
 | 
					                                 const std::pair<float, float>& output_range,
 | 
				
			||||||
 | 
					                                 int max_num_channels, float* tensor_ptr) {
 | 
				
			||||||
 | 
					  return NormalizeImage<uint8_t>(image_frame, flip_vertically, output_range,
 | 
				
			||||||
 | 
					                                 max_num_channels, tensor_ptr);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::Status NormalizeFloatImage(const ImageFrame& image_frame,
 | 
				
			||||||
 | 
					                                 bool flip_vertically,
 | 
				
			||||||
 | 
					                                 const std::pair<float, float>& output_range,
 | 
				
			||||||
 | 
					                                 int max_num_channels, float* tensor_ptr) {
 | 
				
			||||||
 | 
					  return NormalizeImage<float>(image_frame, flip_vertically, output_range,
 | 
				
			||||||
 | 
					                               max_num_channels, tensor_ptr);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::Status CopyMatrixToTensor(const Matrix& matrix, bool is_row_major_matrix,
 | 
				
			||||||
 | 
					                                float* tensor_ptr) {
 | 
				
			||||||
 | 
					  if (is_row_major_matrix) {
 | 
				
			||||||
 | 
					    auto matrix_map =
 | 
				
			||||||
 | 
					        Eigen::Map<RowMajorMatrixXf>(tensor_ptr, matrix.rows(), matrix.cols());
 | 
				
			||||||
 | 
					    matrix_map = matrix;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto matrix_map =
 | 
				
			||||||
 | 
					        Eigen::Map<ColMajorMatrixXf>(tensor_ptr, matrix.rows(), matrix.cols());
 | 
				
			||||||
 | 
					    matrix_map = matrix;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return absl::OkStatus();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Converts `image_frame` into a float32 Tensor of shape
					// {1, height, width, min(channels, max_num_channels)}.
					//
					// The tensor buffer is filled by NormalizeUInt8Image (byte depth 1) or
					// NormalizeFloatImage (byte depth 4); both receive `flip_vertically`,
					// `output_range` and `max_num_channels` unchanged.
					//
					// Returns an error when the image format is not one of SRGBA, SRGB, GRAY8 or
					// VEC32F1, when the byte depth is neither 1 nor 4, or when the normalization
					// helper reports a failure.
					absl::StatusOr<Tensor> ConvertImageFrameToTensorOnCpu(
					    const ImageFrame& image_frame, const std::pair<float, float>& output_range,
					    bool flip_vertically, int max_num_channels) {
					  const mediapipe::ImageFormat::Format input_format = image_frame.Format();
					  const bool is_supported_format =
					      input_format == mediapipe::ImageFormat::SRGBA ||
					      input_format == mediapipe::ImageFormat::SRGB ||
					      input_format == mediapipe::ImageFormat::GRAY8 ||
					      input_format == mediapipe::ImageFormat::VEC32F1;
					  if (!is_supported_format) {
					    RET_CHECK_FAIL() << "Unsupported CPU input format.";
					  }

					  // The tensor never carries more channels than the image provides.
					  const int image_rows = image_frame.Height();
					  const int image_cols = image_frame.Width();
					  const int image_channels = image_frame.NumberOfChannels();
					  const int tensor_channels = std::min(image_channels, max_num_channels);

					  Tensor output_tensor(
					      Tensor::ElementType::kFloat32,
					      Tensor::Shape{1, image_rows, image_cols, tensor_channels});
					  auto write_view = output_tensor.GetCpuWriteView();

					  // Dispatch on the per-channel byte depth of the source image.
					  const int byte_depth = image_frame.ByteDepth();
					  if (byte_depth == 1) {
					    MP_RETURN_IF_ERROR(NormalizeUInt8Image(image_frame, flip_vertically,
					                                           output_range, max_num_channels,
					                                           write_view.buffer<float>()));
					    return output_tensor;
					  }
					  if (byte_depth == 4) {
					    MP_RETURN_IF_ERROR(NormalizeFloatImage(image_frame, flip_vertically,
					                                           output_range, max_num_channels,
					                                           write_view.buffer<float>()));
					    return output_tensor;
					  }
					  return absl::InternalError(
					      "Only byte-based (8 bit) and float (32 bit) images supported.");
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Copies `matrix` into a new float32 Tensor of shape
					// {1, matrix.rows(), matrix.cols(), 1}.
					//
					// The element layout is delegated to CopyMatrixToTensor, which receives
					// `row_major_matrix` unchanged. Returns its error status on failure.
					absl::StatusOr<Tensor> ConvertMatrixToTensorOnCpu(const Matrix& matrix,
					                                                  bool row_major_matrix) {
					  constexpr int kNumChannels = 1;
					  const int num_rows = matrix.rows();
					  const int num_cols = matrix.cols();

					  Tensor output_tensor(Tensor::ElementType::kFloat32,
					                       Tensor::Shape{1, num_rows, num_cols, kNumChannels});
					  {
					    // Scope the write view so it is released before the tensor is returned.
					    auto write_view = output_tensor.GetCpuWriteView();
					    MP_RETURN_IF_ERROR(CopyMatrixToTensor(matrix, row_major_matrix,
					                                          write_view.buffer<float>()));
					  }
					  return output_tensor;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
							
								
								
									
										61
									
								
								mediapipe/calculators/tensor/tensor_converter_cpu.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								mediapipe/calculators/tensor/tensor_converter_cpu.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,61 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_CONVERTER_CPU_H_
 | 
				
			||||||
 | 
					#define MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_CONVERTER_CPU_H_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
 | 
					#include "absl/status/statusor.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/image_frame.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/matrix.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Converts an ImageFrame to a single float32 Tensor.
 | 
				
			||||||
 | 
					// @flip_vertically enables to flip the image during conversion.
 | 
				
			||||||
 | 
					// @max_num_channels caps the number of channels in the output tensor (the
 | 
				
			||||||
 | 
					// smaller of this value and the image's channel count is used).
 | 
				
			||||||
 | 
					// Returns output Tensor.
 | 
				
			||||||
 | 
					absl::StatusOr<Tensor> ConvertImageFrameToTensorOnCpu(
 | 
				
			||||||
 | 
					    const ImageFrame& image_frame, const std::pair<float, float>& output_range,
 | 
				
			||||||
 | 
					    bool flip_vertically, int max_num_channels);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Converts a Matrix to a single float32 Tensor.
 | 
				
			||||||
 | 
					// @row_major_matrix defines the ordering in the input matrix.
 | 
				
			||||||
 | 
					// The output tensor is float32 with shape
 | 
				
			||||||
 | 
					// {1, matrix.rows(), matrix.cols(), 1}.
 | 
				
			||||||
 | 
					// Returns output Tensor.
 | 
				
			||||||
 | 
					absl::StatusOr<Tensor> ConvertMatrixToTensorOnCpu(const Matrix& matrix,
 | 
				
			||||||
 | 
					                                                  bool row_major_matrix);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// For testing only below.
 | 
				
			||||||
 | 
					absl::Status NormalizeUInt8Image(const ImageFrame& image_frame,
 | 
				
			||||||
 | 
					                                 bool flip_vertically,
 | 
				
			||||||
 | 
					                                 const std::pair<float, float>& output_range,
 | 
				
			||||||
 | 
					                                 int max_num_channels, float* tensor_ptr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::Status NormalizeFloatImage(const ImageFrame& image_frame,
 | 
				
			||||||
 | 
					                                 bool flip_vertically,
 | 
				
			||||||
 | 
					                                 const std::pair<float, float>& output_range,
 | 
				
			||||||
 | 
					                                 int max_num_channels, float* tensor_ptr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::Status CopyMatrixToTensor(const Matrix& matrix, bool is_row_major_matrix,
 | 
				
			||||||
 | 
					                                float* tensor_ptr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif  // MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_CONVERTER_CPU_H_
 | 
				
			||||||
							
								
								
									
										175
									
								
								mediapipe/calculators/tensor/tensor_converter_cpu_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								mediapipe/calculators/tensor/tensor_converter_cpu_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,175 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/calculators/tensor/tensor_converter_cpu.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cstdint>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/matrix.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/status_matchers.h"
 | 
				
			||||||
 | 
					#include "mediapipe/util/image_test_utils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Builds a num_rows x num_columns Matrix whose entry (r, c) holds
					// r * num_columns + c, i.e. 0, 1, 2, ... when read row by row.
					Matrix CreateTestMatrix(int num_rows, int num_columns) {
					  Matrix result(num_rows, num_columns);
					  int next_value = 0;
					  for (int row = 0; row < num_rows; ++row) {
					    for (int column = 0; column < num_columns; ++column) {
					      result(row, column) = next_value++;
					    }
					  }
					  return result;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Row-major copy: tensor element i must equal matrix(i / cols, i % cols).
					TEST(TensorConverterCpuTest, ShouldCopyMatrixInRowMajorFormatToTensor) {
					  Matrix source = CreateTestMatrix(/*num_rows=*/3, /*num_columns=*/4);
					  std::vector<float> copied(source.size(), 0.0f);

					  MP_EXPECT_OK(CopyMatrixToTensor(source, /*is_row_major_matrix=*/true,
					                                  copied.data()));

					  // Walk the matrix row by row; the tensor buffer must list the same values
					  // in exactly that order.
					  const int num_rows = source.rows();
					  const int num_columns = source.cols();
					  int flat_index = 0;
					  for (int row = 0; row < num_rows; ++row) {
					    for (int column = 0; column < num_columns; ++column, ++flat_index) {
					      EXPECT_FLOAT_EQ(copied[flat_index], source(row, column));
					    }
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Column-major copy: tensor element i must equal matrix(i % rows, i / rows).
					TEST(TensorConverterCpuTest, ShouldCopyMatrixInColumnMajorFormatToTensor) {
					  Matrix source = CreateTestMatrix(/*num_rows=*/3, /*num_columns=*/4);
					  std::vector<float> copied(source.size(), 0.0f);

					  MP_EXPECT_OK(CopyMatrixToTensor(source, /*is_row_major_matrix=*/false,
					                                  copied.data()));

					  // Walk the matrix column by column; the tensor buffer must list the same
					  // values in exactly that order.
					  const int num_rows = source.rows();
					  const int num_columns = source.cols();
					  int flat_index = 0;
					  for (int column = 0; column < num_columns; ++column) {
					    for (int row = 0; row < num_rows; ++row, ++flat_index) {
					      EXPECT_FLOAT_EQ(copied[flat_index], source(row, column));
					    }
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// With the output range {0, 1}, every tensor value must equal the source
					// byte divided by 255.
					TEST(TensorConverterCpuTest, ShouldNormalizeGrey8ImageWithDefaultRange) {
					  ImageFrame image = CreateTestGrey8ImageFrame(/*width=*/3, /*height=*/4);
					  const int num_pixels = image.Width() * image.Height();
					  std::vector<float> normalized(num_pixels, 0.0f);
					  const std::pair<float, float> unit_range(0.0f, 1.0f);

					  MP_EXPECT_OK(NormalizeUInt8Image(image, /*flip_vertically=*/false,
					                                   unit_range, /*max_num_channels=*/1,
					                                   normalized.data()));

					  for (int i = 0; i < num_pixels; ++i) {
					    const float expected =
					        static_cast<uint8_t>(image.PixelData()[i]) / 255.0f;
					    EXPECT_FLOAT_EQ(normalized[i], expected);
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// With the output range {2, 3}, every tensor value must equal
					// byte / 255 * (range.second - range.first) + range.first.
					TEST(TensorConverterCpuTest, ShouldNormalizeGrey8ImageWithSpecifiedRange) {
					  ImageFrame image = CreateTestGrey8ImageFrame(/*width=*/3, /*height=*/4);
					  const int num_pixels = image.Width() * image.Height();
					  std::vector<float> normalized(num_pixels, 0.0f);
					  const std::pair<float, float> range(2.0f, 3.0f);

					  MP_EXPECT_OK(NormalizeUInt8Image(image, /*flip_vertically=*/false, range,
					                                   /*max_num_channels=*/1,
					                                   normalized.data()));

					  for (int i = 0; i < num_pixels; ++i) {
					    const float expected = static_cast<uint8_t>(image.PixelData()[i]) /
					                               255.0f * (range.second - range.first) +
					                           range.first;
					    EXPECT_FLOAT_EQ(normalized[i], expected);
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// With flip_vertically set, source row y must land in tensor row
					// (height - y - 1) while columns keep their position.
					TEST(TensorConverterCpuTest, ShouldNormalizeGrey8ImageFlipped) {
					  ImageFrame image = CreateTestGrey8ImageFrame(/*width=*/3, /*height=*/4);
					  const int image_width = image.Width();
					  const int image_height = image.Height();
					  std::vector<float> normalized(image_width * image_height, 0.0f);
					  const std::pair<float, float> unit_range(0.0f, 1.0f);

					  MP_EXPECT_OK(NormalizeUInt8Image(image, /*flip_vertically=*/true,
					                                   unit_range, /*max_num_channels=*/1,
					                                   normalized.data()));

					  for (int y = 0; y < image_height; ++y) {
					    const int flipped_y = image_height - y - 1;
					    for (int x = 0; x < image_width; ++x) {
					      const int source_index = y * image_width + x;
					      const int tensor_index = flipped_y * image_width + x;
					      EXPECT_FLOAT_EQ(
					          normalized[tensor_index],
					          static_cast<uint8_t>(image.PixelData()[source_index]) / 255.0f);
					    }
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// With the output range {0, 1}, every tensor value must equal the source
					// float divided by 255.
					TEST(TensorConverterCpuTest, ShouldNormalizeFloatImageWithDefaultRange) {
					  ImageFrame image = CreateTestFloat32ImageFrame(/*width=*/3, /*height=*/4);
					  const int num_pixels = image.Width() * image.Height();
					  std::vector<float> normalized(num_pixels, 0.0f);
					  const std::pair<float, float> unit_range(0.0f, 1.0f);

					  MP_EXPECT_OK(NormalizeFloatImage(image, /*flip_vertically=*/false,
					                                   unit_range, /*max_num_channels=*/1,
					                                   normalized.data()));

					  for (int i = 0; i < num_pixels; ++i) {
					    // Reinterpret the raw pixel bytes as the floats the frame stores.
					    const float source_value =
					        reinterpret_cast<const float*>(image.PixelData())[i];
					    EXPECT_FLOAT_EQ(normalized[i], source_value / 255.0f);
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// End-to-end check: a GRAY8 frame converted with range {0, 1}, no flip and a
					// single channel must yield byte / 255 for every pixel.
					TEST(TensorConverterCpuTest, ConvertImageFrameToTensorOnCpu) {
					  ImageFrame image = CreateTestGrey8ImageFrame(/*width=*/3, /*height=*/4);
					  const std::pair<float, float> unit_range(0.0f, 1.0f);

					  MP_ASSERT_OK_AND_ASSIGN(
					      Tensor output,
					      ConvertImageFrameToTensorOnCpu(image, unit_range,
					                                     /*flip_vertically=*/false,
					                                     /*max_num_channels=*/1));

					  // Keep the read view alive for as long as the buffer pointer is used.
					  const auto read_view = output.GetCpuReadView();
					  const float* tensor_values = read_view.buffer<float>();
					  const int num_pixels = image.Width() * image.Height();
					  for (int i = 0; i < num_pixels; ++i) {
					    EXPECT_FLOAT_EQ(tensor_values[i],
					                    static_cast<uint8_t>(image.PixelData()[i]) / 255.0);
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// End-to-end check: converting with row_major_matrix=false must produce a
					// tensor buffer equal, element by element, to the matrix's own data() array.
					TEST(TensorConverterCpuTest, ConvertMatrixToTensorOnCpu) {
					  Matrix source = CreateTestMatrix(/*num_rows=*/3, /*num_columns=*/4);

					  MP_ASSERT_OK_AND_ASSIGN(
					      Tensor output,
					      ConvertMatrixToTensorOnCpu(source, /*row_major_matrix=*/false));

					  // Keep the read view alive for as long as the buffer pointer is used.
					  const auto read_view = output.GetCpuReadView();
					  const float* tensor_values = read_view.buffer<float>();
					  const float* matrix_values = source.data();
					  const int num_values = source.size();
					  for (int i = 0; i < num_values; ++i) {
					    EXPECT_FLOAT_EQ(tensor_values[i], matrix_values[i]);
					  }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					@ -17,6 +17,34 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <ImageFormat::Format Format, typename DataType>
 | 
				
			||||||
 | 
					ImageFrame CreateTestImageFrame(int width, int height, DataType max_value) {
 | 
				
			||||||
 | 
					  ImageFrame image_frame(Format, width, height,
 | 
				
			||||||
 | 
					                         /*alignment_boundary=*/1);
 | 
				
			||||||
 | 
					  const int num_channels = image_frame.NumberOfChannels();
 | 
				
			||||||
 | 
					  const float num_values = width * height * num_channels;
 | 
				
			||||||
 | 
					  uint8_t* const data_ptr =
 | 
				
			||||||
 | 
					      reinterpret_cast<uint8_t*>(image_frame.MutablePixelData());
 | 
				
			||||||
 | 
					  for (int y = 0; y < height; ++y) {
 | 
				
			||||||
 | 
					    uint8_t* const row = data_ptr + image_frame.WidthStep() * y;
 | 
				
			||||||
 | 
					    for (int x = 0; x < width; ++x) {
 | 
				
			||||||
 | 
					      DataType* pixel = reinterpret_cast<DataType*>(row) + x * num_channels;
 | 
				
			||||||
 | 
					      for (int c = 0; c < num_channels; ++c) {
 | 
				
			||||||
 | 
					        // Fill pixel channel with a value in [0:max_value] range.
 | 
				
			||||||
 | 
					        pixel[c] =
 | 
				
			||||||
 | 
					            static_cast<DataType>(static_cast<float>(y * width * num_channels +
 | 
				
			||||||
 | 
					                                                     x * num_channels + c) /
 | 
				
			||||||
 | 
					                                  num_values * max_value);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return image_frame;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cv::Mat GetRgb(const std::string& path) {
 | 
					cv::Mat GetRgb(const std::string& path) {
 | 
				
			||||||
  cv::Mat bgr = cv::imread(path);
 | 
					  cv::Mat bgr = cv::imread(path);
 | 
				
			||||||
  cv::Mat rgb;
 | 
					  cv::Mat rgb;
 | 
				
			||||||
| 
						 | 
					@ -71,4 +99,14 @@ cv::Mat RgbaToBgr(cv::Mat rgba) {
 | 
				
			||||||
  return bgra;
 | 
					  return bgra;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Returns a VEC32F1 test frame filled by CreateTestImageFrame with a maximum
					// value of 1.0f.
					ImageFrame CreateTestFloat32ImageFrame(int width, int height) {
					  constexpr float kMaxValue = 1.0f;
					  return CreateTestImageFrame<ImageFormat::VEC32F1, float>(width, height,
					                                                           kMaxValue);
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Returns a GRAY8 test frame filled by CreateTestImageFrame with a maximum
					// value of 255.
					ImageFrame CreateTestGrey8ImageFrame(int width, int height) {
					  constexpr uint8_t kMaxValue = 255;
					  return CreateTestImageFrame<ImageFormat::GRAY8, uint8_t>(width, height,
					                                                           kMaxValue);
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,6 +4,7 @@
 | 
				
			||||||
#include <string>
 | 
					#include <string>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mediapipe/framework/formats/image_format.pb.h"
 | 
					#include "mediapipe/framework/formats/image_format.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/image_frame.h"
 | 
				
			||||||
#include "mediapipe/framework/packet.h"
 | 
					#include "mediapipe/framework/packet.h"
 | 
				
			||||||
#include "mediapipe/framework/port/opencv_core_inc.h"
 | 
					#include "mediapipe/framework/port/opencv_core_inc.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -30,6 +31,12 @@ Packet MakeImagePacket(cv::Mat input, int timestamp = 0);
 | 
				
			||||||
// Converts RGBA Mat to BGR.
 | 
					// Converts RGBA Mat to BGR.
 | 
				
			||||||
cv::Mat RgbaToBgr(cv::Mat rgba);
 | 
					cv::Mat RgbaToBgr(cv::Mat rgba);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Generates a single-channel float32 ImageFrame with increasing [0,1) values.
 | 
				
			||||||
 | 
					ImageFrame CreateTestFloat32ImageFrame(int width, int height);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Generates a single-channel uint8 ImageFrame with increasing [0,255) values.
 | 
				
			||||||
 | 
					ImageFrame CreateTestGrey8ImageFrame(int width, int height);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif  // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
 | 
					#endif  // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user