Project import generated by Copybara.
GitOrigin-RevId: b66251317fbebfbb8e1f2ddc64ea5da84bceb7e5
This commit is contained in:
parent
7fb37c80e8
commit
4a20e9909d
|
@ -22,7 +22,6 @@
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||||
#include "mediapipe/framework/deps/file_path.h"
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
#include "mediapipe/util/tflite/config.h"
|
#include "mediapipe/util/tflite/config.h"
|
||||||
#include "tensorflow/lite/interpreter_builder.h"
|
|
||||||
|
|
||||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
@ -53,11 +52,9 @@ class InferenceCalculatorGlImpl
|
||||||
private:
|
private:
|
||||||
absl::Status ReadGpuCaches();
|
absl::Status ReadGpuCaches();
|
||||||
absl::Status SaveGpuCaches();
|
absl::Status SaveGpuCaches();
|
||||||
absl::Status InitInterpreter(CalculatorContext* cc);
|
absl::Status LoadModel(CalculatorContext* cc);
|
||||||
absl::Status LoadDelegate(CalculatorContext* cc,
|
absl::Status LoadDelegate(CalculatorContext* cc);
|
||||||
tflite::InterpreterBuilder* interpreter_builder);
|
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
||||||
absl::Status BindBuffersToTensors();
|
|
||||||
absl::Status AllocateTensors();
|
|
||||||
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
||||||
|
|
||||||
// TfLite requires us to keep the model alive as long as the interpreter is.
|
// TfLite requires us to keep the model alive as long as the interpreter is.
|
||||||
|
@ -140,11 +137,17 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
||||||
#endif // MEDIAPIPE_ANDROID
|
#endif // MEDIAPIPE_ANDROID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
|
||||||
|
// for everything.
|
||||||
|
if (!use_advanced_gpu_api_) {
|
||||||
|
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||||
|
}
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
||||||
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
|
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
|
||||||
: InitInterpreter(cc);
|
: LoadDelegateAndAllocateTensors(cc);
|
||||||
}));
|
}));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -289,6 +292,9 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
|
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||||
|
const auto& model = *model_packet_.Get();
|
||||||
|
|
||||||
// Create runner
|
// Create runner
|
||||||
tflite::gpu::InferenceOptions options;
|
tflite::gpu::InferenceOptions options;
|
||||||
options.priority1 = allow_precision_loss_
|
options.priority1 = allow_precision_loss_
|
||||||
|
@ -326,12 +332,17 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
if (kSideInOpResolver(cc).IsConnected()) {
|
||||||
const auto& model = *model_packet_.Get();
|
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
|
||||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||||
const auto& op_resolver = op_resolver_packet.Get();
|
model, op_resolver, /*allow_quant_ops=*/true));
|
||||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
} else {
|
||||||
model, op_resolver, /*allow_quant_ops=*/true));
|
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||||
|
kSideInCustomOpResolver(cc).GetOr(
|
||||||
|
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||||
|
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||||
|
model, op_resolver, /*allow_quant_ops=*/true));
|
||||||
|
}
|
||||||
|
|
||||||
// Create and bind OpenGL buffers for outputs.
|
// Create and bind OpenGL buffers for outputs.
|
||||||
// The buffers are created once and their ids are passed to calculator outputs
|
// The buffers are created once and their ids are passed to calculator outputs
|
||||||
|
@ -350,27 +361,35 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
|
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
|
||||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||||
const auto& model = *model_packet_.Get();
|
const auto& model = *model_packet_.Get();
|
||||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
if (kSideInOpResolver(cc).IsConnected()) {
|
||||||
const auto& op_resolver = op_resolver_packet.Get();
|
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
|
||||||
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
|
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
|
||||||
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
|
} else {
|
||||||
|
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||||
|
kSideInCustomOpResolver(cc).GetOr(
|
||||||
|
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||||
|
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
|
||||||
|
}
|
||||||
|
RET_CHECK(interpreter_);
|
||||||
|
|
||||||
#if defined(__EMSCRIPTEN__)
|
#if defined(__EMSCRIPTEN__)
|
||||||
interpreter_builder.SetNumThreads(1);
|
interpreter_->SetNumThreads(1);
|
||||||
#else
|
#else
|
||||||
interpreter_builder.SetNumThreads(
|
interpreter_->SetNumThreads(
|
||||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||||
#endif // __EMSCRIPTEN__
|
#endif // __EMSCRIPTEN__
|
||||||
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
|
|
||||||
RET_CHECK(interpreter_);
|
|
||||||
MP_RETURN_IF_ERROR(BindBuffersToTensors());
|
|
||||||
MP_RETURN_IF_ERROR(AllocateTensors());
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
|
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
|
||||||
|
CalculatorContext* cc) {
|
||||||
|
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||||
|
|
||||||
|
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
|
||||||
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||||
// TODO: Support quantized tensors.
|
// TODO: Support quantized tensors.
|
||||||
RET_CHECK_NE(
|
RET_CHECK_NE(
|
||||||
|
@ -379,8 +398,7 @@ absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::LoadDelegate(
|
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
|
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||||
options.compile_options.precision_loss_allowed =
|
options.compile_options.precision_loss_allowed =
|
||||||
|
@ -391,11 +409,7 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(
|
||||||
options.compile_options.inline_parameters = 1;
|
options.compile_options.inline_parameters = 1;
|
||||||
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
|
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
|
||||||
&TfLiteGpuDelegateDelete);
|
&TfLiteGpuDelegateDelete);
|
||||||
interpreter_builder->AddDelegate(delegate_.get());
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
|
|
||||||
// Get input image sizes.
|
// Get input image sizes.
|
||||||
const auto& input_indices = interpreter_->inputs();
|
const auto& input_indices = interpreter_->inputs();
|
||||||
for (int i = 0; i < input_indices.size(); ++i) {
|
for (int i = 0; i < input_indices.size(); ++i) {
|
||||||
|
@ -427,6 +441,11 @@ absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
|
||||||
output_indices[i]),
|
output_indices[i]),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must call this last.
|
||||||
|
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||||
|
kTfLiteOk);
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -180,7 +180,7 @@ class Packet {
|
||||||
// Returns an error if the packet does not contain data of type T.
|
// Returns an error if the packet does not contain data of type T.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
absl::Status ValidateAsType() const {
|
absl::Status ValidateAsType() const {
|
||||||
return ValidateAsType(tool::TypeId<T>());
|
return ValidateAsType(tool::TypeInfo::Get<T>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns an error if the packet is not an instance of
|
// Returns an error if the packet is not an instance of
|
||||||
|
@ -428,7 +428,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
|
||||||
ConvertToVectorOfProtoMessageLitePtrs(const T* data,
|
ConvertToVectorOfProtoMessageLitePtrs(const T* data,
|
||||||
/*is_proto_vector=*/std::false_type) {
|
/*is_proto_vector=*/std::false_type) {
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
"The Packet stores \"", tool::TypeId<T>().name(), "\"",
|
"The Packet stores \"", tool::TypeInfo::Get<T>().name(), "\"",
|
||||||
"which is not convertible to vector<proto_ns::MessageLite*>."));
|
"which is not convertible to vector<proto_ns::MessageLite*>."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -510,7 +510,9 @@ class Holder : public HolderBase {
|
||||||
HolderSupport<T>::EnsureStaticInit();
|
HolderSupport<T>::EnsureStaticInit();
|
||||||
return *ptr_;
|
return *ptr_;
|
||||||
}
|
}
|
||||||
const tool::TypeInfo& GetTypeInfo() const final { return tool::TypeId<T>(); }
|
const tool::TypeInfo& GetTypeInfo() const final {
|
||||||
|
return tool::TypeInfo::Get<T>();
|
||||||
|
}
|
||||||
// Releases the underlying data pointer and transfers the ownership to a
|
// Releases the underlying data pointer and transfers the ownership to a
|
||||||
// unique pointer.
|
// unique pointer.
|
||||||
// This method is dangerous and is only used by Packet::Consume() if the
|
// This method is dangerous and is only used by Packet::Consume() if the
|
||||||
|
|
|
@ -259,14 +259,14 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
PacketType& PacketType::Set() {
|
PacketType& PacketType::Set() {
|
||||||
type_spec_ = &tool::TypeId<T>();
|
type_spec_ = &tool::TypeInfo::Get<T>();
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
PacketType& PacketType::SetOneOf() {
|
PacketType& PacketType::SetOneOf() {
|
||||||
static const NoDestructor<std::vector<const tool::TypeInfo*>> types{
|
static const NoDestructor<std::vector<const tool::TypeInfo*>> types{
|
||||||
{&tool::TypeId<T>()...}};
|
{&tool::TypeInfo::Get<T>()...}};
|
||||||
static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
|
static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
|
||||||
type_spec_ = MultiType{*types, &*name};
|
type_spec_ = MultiType{*types, &*name};
|
||||||
return *this;
|
return *this;
|
||||||
|
|
|
@ -761,9 +761,11 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/port:advanced_proto",
|
"//mediapipe/framework/port:advanced_proto",
|
||||||
"//mediapipe/framework/port:file_helpers",
|
"//mediapipe/framework/port:file_helpers",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
"@com_google_absl//absl/cleanup",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
|
|
|
@ -58,14 +58,14 @@ class TypeMap {
|
||||||
public:
|
public:
|
||||||
template <class T>
|
template <class T>
|
||||||
bool Has() const {
|
bool Has() const {
|
||||||
return content_.count(TypeId<T>()) > 0;
|
return content_.count(TypeInfo::Get<T>()) > 0;
|
||||||
}
|
}
|
||||||
template <class T>
|
template <class T>
|
||||||
T* Get() const {
|
T* Get() const {
|
||||||
if (!Has<T>()) {
|
if (!Has<T>()) {
|
||||||
content_[TypeId<T>()] = std::make_shared<T>();
|
content_[TypeInfo::Get<T>()] = std::make_shared<T>();
|
||||||
}
|
}
|
||||||
return static_cast<T*>(content_[TypeId<T>()].get());
|
return static_cast<T*>(content_[TypeInfo::Get<T>()].get());
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/cleanup/cleanup.h"
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
@ -33,6 +34,7 @@
|
||||||
#include "mediapipe/framework/formats/image_format.pb.h"
|
#include "mediapipe/framework/formats/image_format.pb.h"
|
||||||
#include "mediapipe/framework/port/advanced_proto_inc.h"
|
#include "mediapipe/framework/port/advanced_proto_inc.h"
|
||||||
#include "mediapipe/framework/port/file_helpers.h"
|
#include "mediapipe/framework/port/file_helpers.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
#include "mediapipe/framework/port/proto_ns.h"
|
#include "mediapipe/framework/port/proto_ns.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
@ -208,6 +210,27 @@ bool CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status CompareAndSaveImageOutput(
|
||||||
|
absl::string_view golden_image_path, const ImageFrame& actual,
|
||||||
|
const ImageFrameComparisonOptions& options) {
|
||||||
|
ASSIGN_OR_RETURN(auto output_img_path, SavePngTestOutput(actual, "output"));
|
||||||
|
|
||||||
|
auto expected = LoadTestImage(GetTestFilePath(golden_image_path));
|
||||||
|
if (!expected.ok()) {
|
||||||
|
return expected.status();
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(auto expected_img_path,
|
||||||
|
SavePngTestOutput(**expected, "expected"));
|
||||||
|
|
||||||
|
std::unique_ptr<ImageFrame> diff_img;
|
||||||
|
auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
|
||||||
|
options.max_alpha_diff, options.max_avg_diff,
|
||||||
|
diff_img);
|
||||||
|
ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
std::string GetTestRootDir() {
|
std::string GetTestRootDir() {
|
||||||
return file::JoinPath(std::getenv("TEST_SRCDIR"), "mediapipe");
|
return file::JoinPath(std::getenv("TEST_SRCDIR"), "mediapipe");
|
||||||
}
|
}
|
||||||
|
@ -275,6 +298,23 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write an ImageFrame as PNG to the test undeclared outputs directory.
|
||||||
|
// The image's name will contain the given prefix and a timestamp.
|
||||||
|
// Returns the path to the output if successful.
|
||||||
|
absl::StatusOr<std::string> SavePngTestOutput(
|
||||||
|
const mediapipe::ImageFrame& image, absl::string_view prefix) {
|
||||||
|
std::string now_string = absl::FormatTime(absl::Now());
|
||||||
|
std::string output_relative_path =
|
||||||
|
absl::StrCat(prefix, "_", now_string, ".png");
|
||||||
|
std::string output_full_path =
|
||||||
|
file::JoinPath(GetTestOutputsDir(), output_relative_path);
|
||||||
|
RET_CHECK(stbi_write_png(output_full_path.c_str(), image.Width(),
|
||||||
|
image.Height(), image.NumberOfChannels(),
|
||||||
|
image.PixelData(), image.WidthStep()))
|
||||||
|
<< " path: " << output_full_path;
|
||||||
|
return output_relative_path;
|
||||||
|
}
|
||||||
|
|
||||||
bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) {
|
bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) {
|
||||||
int fd = open(path.c_str(), O_RDONLY);
|
int fd = open(path.c_str(), O_RDONLY);
|
||||||
if (fd == -1) {
|
if (fd == -1) {
|
||||||
|
|
|
@ -22,20 +22,33 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
using mediapipe::CalculatorGraphConfig;
|
using mediapipe::CalculatorGraphConfig;
|
||||||
|
|
||||||
|
struct ImageFrameComparisonOptions {
|
||||||
|
// NOTE: these values are not normalized: use a value from 0 to 2^8-1
|
||||||
|
// for 8-bit data and a value from 0 to 2^16-1 for 16-bit data.
|
||||||
|
// Although these members are declared as floats,, all uint8/uint16
|
||||||
|
// values are exactly representable. (2^24 + 1 is the first non-representable
|
||||||
|
// positive integral value.)
|
||||||
|
|
||||||
|
// Maximum value difference allowed for non-alpha channels.
|
||||||
|
float max_color_diff;
|
||||||
|
// Maximum value difference allowed for alpha channel (if present).
|
||||||
|
float max_alpha_diff;
|
||||||
|
// Maximum difference for all channels, averaged across all pixels.
|
||||||
|
float max_avg_diff;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compares an output image with a golden file. Saves the output and difference
|
||||||
|
// to the undeclared test outputs.
|
||||||
|
// Returns ok if they are equal within the tolerances specified in options.
|
||||||
|
absl::Status CompareAndSaveImageOutput(
|
||||||
|
absl::string_view golden_image_path, const ImageFrame& actual,
|
||||||
|
const ImageFrameComparisonOptions& options);
|
||||||
|
|
||||||
// Checks if two image frames are equal within the specified tolerance.
|
// Checks if two image frames are equal within the specified tolerance.
|
||||||
// image1 and image2 may be of different-but-compatible image formats (e.g.,
|
// image1 and image2 may be of different-but-compatible image formats (e.g.,
|
||||||
// SRGB and SRGBA); in that case, only the channels available in both are
|
// SRGB and SRGBA); in that case, only the channels available in both are
|
||||||
// compared.
|
// compared.
|
||||||
// max_color_diff applies to the first 3 channels; i.e., R, G, B for sRGB and
|
// The diff arguments are as in ImageFrameComparisonOptions.
|
||||||
// sRGBA, and the single gray channel for GRAY8 and GRAY16. It is the maximum
|
|
||||||
// pixel color value difference allowed; i.e., a value from 0 to 2^8-1 for 8-bit
|
|
||||||
// data and a value from 0 to 2^16-1 for 16-bit data.
|
|
||||||
// max_alpha_diff applies to the 4th (alpha) channel only, if present.
|
|
||||||
// max_avg_diff applies to all channels, normalized across all pixels.
|
|
||||||
//
|
|
||||||
// Note: Although max_color_diff and max_alpha_diff are floats, all uint8/uint16
|
|
||||||
// values are exactly representable. (2^24 + 1 is the first non-representable
|
|
||||||
// positive integral value.)
|
|
||||||
absl::Status CompareImageFrames(const ImageFrame& image1,
|
absl::Status CompareImageFrames(const ImageFrame& image1,
|
||||||
const ImageFrame& image2,
|
const ImageFrame& image2,
|
||||||
const float max_color_diff,
|
const float max_color_diff,
|
||||||
|
@ -77,6 +90,13 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
|
||||||
std::unique_ptr<ImageFrame> LoadTestPng(
|
std::unique_ptr<ImageFrame> LoadTestPng(
|
||||||
absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
|
absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
|
||||||
|
|
||||||
|
// Write an ImageFrame as PNG to the test undeclared outputs directory.
|
||||||
|
// The image's name will contain the given prefix and a timestamp.
|
||||||
|
// If successful, returns the path to the output file relative to the output
|
||||||
|
// directory.
|
||||||
|
absl::StatusOr<std::string> SavePngTestOutput(
|
||||||
|
const mediapipe::ImageFrame& image, absl::string_view prefix);
|
||||||
|
|
||||||
// Returns the luminance image of |original_image|.
|
// Returns the luminance image of |original_image|.
|
||||||
// The format of |original_image| must be sRGB or sRGBA.
|
// The format of |original_image| must be sRGB or sRGBA.
|
||||||
std::unique_ptr<ImageFrame> GenerateLuminanceImage(
|
std::unique_ptr<ImageFrame> GenerateLuminanceImage(
|
||||||
|
|
|
@ -78,12 +78,6 @@ class TypeIndex {
|
||||||
const TypeInfo& info_;
|
const TypeInfo& info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns a unique identifier for type T.
|
|
||||||
template <typename T>
|
|
||||||
const TypeInfo& TypeId() {
|
|
||||||
return TypeInfo::Get<T>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper method that returns a hash code of the given type. This allows for
|
// Helper method that returns a hash code of the given type. This allows for
|
||||||
// typeid testing across multiple binaries, unlike FastTypeId which used a
|
// typeid testing across multiple binaries, unlike FastTypeId which used a
|
||||||
// memory location that only works within the same binary. Moreover, we use this
|
// memory location that only works within the same binary. Moreover, we use this
|
||||||
|
@ -94,7 +88,7 @@ const TypeInfo& TypeId() {
|
||||||
// as much as possible.
|
// as much as possible.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
size_t GetTypeHash() {
|
size_t GetTypeHash() {
|
||||||
return TypeId<T>().hash_code();
|
return TypeInfo::Get<T>().hash_code();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tool
|
} // namespace tool
|
||||||
|
|
|
@ -386,7 +386,7 @@ inline std::string MediaPipeTypeStringOrDemangled(
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::string MediaPipeTypeStringOrDemangled() {
|
std::string MediaPipeTypeStringOrDemangled() {
|
||||||
return MediaPipeTypeStringOrDemangled(tool::TypeId<T>());
|
return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns type hash id of type identified by type_string or NULL if not
|
// Returns type hash id of type identified by type_string or NULL if not
|
||||||
|
|
|
@ -26,22 +26,6 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Write an ImageFrame as PNG to the test undeclared outputs directory.
|
|
||||||
// The image's name will contain the given prefix and a timestamp.
|
|
||||||
// Returns the path to the output if successful.
|
|
||||||
std::string SavePngImage(const mediapipe::ImageFrame& image,
|
|
||||||
absl::string_view prefix) {
|
|
||||||
std::string output_dir = mediapipe::GetTestOutputsDir();
|
|
||||||
std::string now_string = absl::FormatTime(absl::Now());
|
|
||||||
std::string out_file_path =
|
|
||||||
absl::StrCat(output_dir, "/", prefix, "_", now_string, ".png");
|
|
||||||
EXPECT_TRUE(stbi_write_png(out_file_path.c_str(), image.Width(),
|
|
||||||
image.Height(), image.NumberOfChannels(),
|
|
||||||
image.PixelData(), image.WidthStep()))
|
|
||||||
<< " path: " << out_file_path;
|
|
||||||
return out_file_path;
|
|
||||||
}
|
|
||||||
|
|
||||||
void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) {
|
void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) {
|
||||||
auto* data = image.MutablePixelData();
|
auto* data = image.MutablePixelData();
|
||||||
for (int y = 0; y < image.Height(); ++y) {
|
for (int y = 0; y < image.Height(); ++y) {
|
||||||
|
@ -143,8 +127,8 @@ TEST_F(GpuBufferTest, GlTextureView) {
|
||||||
FillImageFrameRGBA(red, 255, 0, 0, 255);
|
FillImageFrameRGBA(red, 255, 0, 0, 255);
|
||||||
|
|
||||||
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
|
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
|
||||||
SavePngImage(red, "gltv_red_gold");
|
MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold"));
|
||||||
SavePngImage(*view, "gltv_red_view");
|
MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GpuBufferTest, ImageFrame) {
|
TEST_F(GpuBufferTest, ImageFrame) {
|
||||||
|
@ -178,8 +162,8 @@ TEST_F(GpuBufferTest, ImageFrame) {
|
||||||
FillImageFrameRGBA(red, 255, 0, 0, 255);
|
FillImageFrameRGBA(red, 255, 0, 0, 255);
|
||||||
|
|
||||||
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
|
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
|
||||||
SavePngImage(red, "if_red_gold");
|
MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold"));
|
||||||
SavePngImage(*view, "if_red_view");
|
MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,8 +196,8 @@ TEST_F(GpuBufferTest, Overwrite) {
|
||||||
FillImageFrameRGBA(red, 255, 0, 0, 255);
|
FillImageFrameRGBA(red, 255, 0, 0, 255);
|
||||||
|
|
||||||
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
|
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
|
||||||
SavePngImage(red, "ow_red_gold");
|
MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold"));
|
||||||
SavePngImage(*view, "ow_red_view");
|
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view"));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -246,8 +230,8 @@ TEST_F(GpuBufferTest, Overwrite) {
|
||||||
FillImageFrameRGBA(green, 0, 255, 0, 255);
|
FillImageFrameRGBA(green, 0, 255, 0, 255);
|
||||||
|
|
||||||
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0));
|
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0));
|
||||||
SavePngImage(green, "ow_green_gold");
|
MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold"));
|
||||||
SavePngImage(*view, "ow_green_view");
|
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view"));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -256,8 +240,8 @@ TEST_F(GpuBufferTest, Overwrite) {
|
||||||
FillImageFrameRGBA(blue, 0, 0, 255, 255);
|
FillImageFrameRGBA(blue, 0, 0, 255, 255);
|
||||||
|
|
||||||
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0));
|
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0));
|
||||||
SavePngImage(blue, "ow_blue_gold");
|
MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold"));
|
||||||
SavePngImage(*view, "ow_blue_view");
|
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ package com.google.mediapipe.framework;
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
import com.google.common.flogger.FluentLogger;
|
import com.google.common.flogger.FluentLogger;
|
||||||
import com.google.mediapipe.framework.ProtoUtil.SerializedMessage;
|
import com.google.mediapipe.framework.ProtoUtil.SerializedMessage;
|
||||||
|
import com.google.protobuf.Internal;
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
import com.google.protobuf.MessageLite;
|
import com.google.protobuf.MessageLite;
|
||||||
import com.google.protobuf.Parser;
|
import com.google.protobuf.Parser;
|
||||||
|
@ -119,11 +120,20 @@ public final class PacketGetter {
|
||||||
return nativeGetProtoBytes(packet.getNativeHandle());
|
return nativeGetProtoBytes(packet.getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T extends MessageLite> T getProto(final Packet packet, Class<T> clazz)
|
public static <T extends MessageLite> T getProto(final Packet packet, T defaultInstance)
|
||||||
throws InvalidProtocolBufferException {
|
throws InvalidProtocolBufferException {
|
||||||
SerializedMessage result = new SerializedMessage();
|
SerializedMessage result = new SerializedMessage();
|
||||||
nativeGetProto(packet.getNativeHandle(), result);
|
nativeGetProto(packet.getNativeHandle(), result);
|
||||||
return ProtoUtil.unpack(result, clazz);
|
return ProtoUtil.unpack(result, defaultInstance);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated {@link #getProto(Packet, MessageLite)} is safer to use in obfuscated builds.
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static <T extends MessageLite> T getProto(final Packet packet, Class<T> clazz)
|
||||||
|
throws InvalidProtocolBufferException {
|
||||||
|
return getProto(packet, Internal.getDefaultInstance(clazz));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static short[] getInt16Vector(final Packet packet) {
|
public static short[] getInt16Vector(final Packet packet) {
|
||||||
|
@ -162,6 +172,13 @@ public final class PacketGetter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static <T extends MessageLite> List<T> getProtoVector(
|
||||||
|
final Packet packet, T defaultInstance) {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Parser<T> parser = (Parser<T>) defaultInstance.getParserForType();
|
||||||
|
return getProtoVector(packet, parser);
|
||||||
|
}
|
||||||
|
|
||||||
public static int getImageWidth(final Packet packet) {
|
public static int getImageWidth(final Packet packet) {
|
||||||
return nativeGetImageWidth(packet.getNativeHandle());
|
return nativeGetImageWidth(packet.getNativeHandle());
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
package com.google.mediapipe.framework;
|
package com.google.mediapipe.framework;
|
||||||
|
|
||||||
import com.google.protobuf.ExtensionRegistryLite;
|
import com.google.protobuf.ExtensionRegistryLite;
|
||||||
import com.google.protobuf.Internal;
|
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
import com.google.protobuf.MessageLite;
|
import com.google.protobuf.MessageLite;
|
||||||
import java.util.NoSuchElementException;
|
import java.util.NoSuchElementException;
|
||||||
|
@ -52,10 +51,8 @@ public final class ProtoUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Deserializes a MessageLite from a SerializedMessage object. */
|
/** Deserializes a MessageLite from a SerializedMessage object. */
|
||||||
public static <T extends MessageLite> T unpack(
|
public static <T extends MessageLite> T unpack(SerializedMessage serialized, T defaultInstance)
|
||||||
SerializedMessage serialized, java.lang.Class<T> clazz)
|
|
||||||
throws InvalidProtocolBufferException {
|
throws InvalidProtocolBufferException {
|
||||||
T defaultInstance = Internal.getDefaultInstance(clazz);
|
|
||||||
String expectedType = ProtoUtil.getTypeName(defaultInstance.getClass());
|
String expectedType = ProtoUtil.getTypeName(defaultInstance.getClass());
|
||||||
if (!serialized.typeName.equals(expectedType)) {
|
if (!serialized.typeName.equals(expectedType)) {
|
||||||
throw new InvalidProtocolBufferException(
|
throw new InvalidProtocolBufferException(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user