Project import generated by Copybara.
GitOrigin-RevId: b66251317fbebfbb8e1f2ddc64ea5da84bceb7e5
parent 7fb37c80e8
commit 4a20e9909d
@@ -22,7 +22,6 @@
 #include "mediapipe/calculators/tensor/inference_calculator.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/util/tflite/config.h"
-#include "tensorflow/lite/interpreter_builder.h"
 
 #if MEDIAPIPE_TFLITE_GL_INFERENCE
 #include "mediapipe/gpu/gl_calculator_helper.h"
@@ -53,11 +52,9 @@ class InferenceCalculatorGlImpl
  private:
   absl::Status ReadGpuCaches();
   absl::Status SaveGpuCaches();
-  absl::Status InitInterpreter(CalculatorContext* cc);
-  absl::Status LoadDelegate(CalculatorContext* cc,
-                            tflite::InterpreterBuilder* interpreter_builder);
-  absl::Status BindBuffersToTensors();
-  absl::Status AllocateTensors();
+  absl::Status LoadModel(CalculatorContext* cc);
+  absl::Status LoadDelegate(CalculatorContext* cc);
+  absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
   absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
 
   // TfLite requires us to keep the model alive as long as the interpreter is.
@@ -140,11 +137,17 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
 #endif  // MEDIAPIPE_ANDROID
   }
 
+  // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
+  // for everything.
+  if (!use_advanced_gpu_api_) {
+    MP_RETURN_IF_ERROR(LoadModel(cc));
+  }
+
   MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
   MP_RETURN_IF_ERROR(
       gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
         return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
-                                     : InitInterpreter(cc);
+                                     : LoadDelegateAndAllocateTensors(cc);
       }));
   return absl::OkStatus();
 }
@@ -289,6 +292,9 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
 
 absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
     CalculatorContext* cc) {
+  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
+  const auto& model = *model_packet_.Get();
+
   // Create runner
   tflite::gpu::InferenceOptions options;
   options.priority1 = allow_precision_loss_
@@ -326,12 +332,17 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
       break;
     }
   }
-  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
-  const auto& model = *model_packet_.Get();
-  ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
-  const auto& op_resolver = op_resolver_packet.Get();
-  MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
-      model, op_resolver, /*allow_quant_ops=*/true));
+  if (kSideInOpResolver(cc).IsConnected()) {
+    const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
+    MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
+        model, op_resolver, /*allow_quant_ops=*/true));
+  } else {
+    tflite::ops::builtin::BuiltinOpResolver op_resolver =
+        kSideInCustomOpResolver(cc).GetOr(
+            tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
+    MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
+        model, op_resolver, /*allow_quant_ops=*/true));
+  }
 
   // Create and bind OpenGL buffers for outputs.
   // The buffers are created once and their ids are passed to calculator outputs
@@ -350,27 +361,35 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
   return absl::OkStatus();
 }
 
-absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
+absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
   ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
   const auto& model = *model_packet_.Get();
-  ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
-  const auto& op_resolver = op_resolver_packet.Get();
-  tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
-  MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
+  if (kSideInOpResolver(cc).IsConnected()) {
+    const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
+    tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
+  } else {
+    tflite::ops::builtin::BuiltinOpResolver op_resolver =
+        kSideInCustomOpResolver(cc).GetOr(
+            tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
+    tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
+  }
+  RET_CHECK(interpreter_);
+
 #if defined(__EMSCRIPTEN__)
-  interpreter_builder.SetNumThreads(1);
+  interpreter_->SetNumThreads(1);
 #else
-  interpreter_builder.SetNumThreads(
+  interpreter_->SetNumThreads(
       cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
 #endif  // __EMSCRIPTEN__
-  RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
-  RET_CHECK(interpreter_);
-  MP_RETURN_IF_ERROR(BindBuffersToTensors());
-  MP_RETURN_IF_ERROR(AllocateTensors());
+
   return absl::OkStatus();
 }
 
-absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
+absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
+    CalculatorContext* cc) {
+  MP_RETURN_IF_ERROR(LoadDelegate(cc));
+
+  // AllocateTensors() can be called only after ModifyGraphWithDelegate.
   RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
   // TODO: Support quantized tensors.
   RET_CHECK_NE(
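For readers less familiar with this TensorFlow Lite call style, tflite::InterpreterBuilder is a callable object: invoking it with the address of an interpreter pointer builds the interpreter, which is what the restored LoadModel above does. The following is only a hedged, standalone sketch of that pattern with plain TFLite, not code from this commit:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

// Builds an interpreter from an already-loaded flatbuffer model, mirroring the
// InterpreterBuilder(model, op_resolver)(&interpreter_) call in LoadModel.
// Returns nullptr if building fails.
std::unique_ptr<tflite::Interpreter> BuildInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinOpResolver op_resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, op_resolver)(&interpreter);
  return interpreter;
}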
@@ -379,8 +398,7 @@ absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
   return absl::OkStatus();
 }
 
-absl::Status InferenceCalculatorGlImpl::LoadDelegate(
-    CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
+absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
   // Configure and create the delegate.
   TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
   options.compile_options.precision_loss_allowed =
@@ -391,11 +409,7 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(
   options.compile_options.inline_parameters = 1;
   delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
                                 &TfLiteGpuDelegateDelete);
-  interpreter_builder->AddDelegate(delegate_.get());
-  return absl::OkStatus();
-}
 
-absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
   // Get input image sizes.
   const auto& input_indices = interpreter_->inputs();
   for (int i = 0; i < input_indices.size(); ++i) {
@@ -427,6 +441,11 @@ absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
                      output_indices[i]),
                  kTfLiteOk);
   }
+
+  // Must call this last.
+  RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
+               kTfLiteOk);
+
   return absl::OkStatus();
 }
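The two comments restored above ("AllocateTensors() can be called only after ModifyGraphWithDelegate." and "Must call this last.") describe the ordering the GL path depends on. Below is a hedged sketch of that ordering with the plain TFLite GL delegate API, independent of the calculator; delegate creation and cleanup are left to the caller:

#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#include "tensorflow/lite/interpreter.h"

// Attaches the OpenGL GPU delegate and only then allocates tensors, matching
// the order the calculator enforces: ModifyGraphWithDelegate() first,
// AllocateTensors() second. The delegate must stay alive as long as the
// interpreter uses it (the calculator keeps it in its delegate_ member).
TfLiteStatus AttachGlDelegateAndAllocate(tflite::Interpreter* interpreter,
                                         TfLiteDelegate* delegate) {
  if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
    return kTfLiteError;
  }
  return interpreter->AllocateTensors();
}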
@@ -180,7 +180,7 @@ class Packet {
   // Returns an error if the packet does not contain data of type T.
   template <typename T>
   absl::Status ValidateAsType() const {
-    return ValidateAsType(tool::TypeId<T>());
+    return ValidateAsType(tool::TypeInfo::Get<T>());
   }
 
   // Returns an error if the packet is not an instance of
@@ -428,7 +428,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
 ConvertToVectorOfProtoMessageLitePtrs(const T* data,
                                       /*is_proto_vector=*/std::false_type) {
   return absl::InvalidArgumentError(absl::StrCat(
-      "The Packet stores \"", tool::TypeId<T>().name(), "\"",
+      "The Packet stores \"", tool::TypeInfo::Get<T>().name(), "\"",
       "which is not convertible to vector<proto_ns::MessageLite*>."));
 }
@@ -510,7 +510,9 @@ class Holder : public HolderBase {
     HolderSupport<T>::EnsureStaticInit();
     return *ptr_;
   }
-  const tool::TypeInfo& GetTypeInfo() const final { return tool::TypeId<T>(); }
+  const tool::TypeInfo& GetTypeInfo() const final {
+    return tool::TypeInfo::Get<T>();
+  }
   // Releases the underlying data pointer and transfers the ownership to a
   // unique pointer.
   // This method is dangerous and is only used by Packet::Consume() if the
@@ -259,14 +259,14 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set);
 
 template <typename T>
 PacketType& PacketType::Set() {
-  type_spec_ = &tool::TypeId<T>();
+  type_spec_ = &tool::TypeInfo::Get<T>();
   return *this;
 }
 
 template <typename... T>
 PacketType& PacketType::SetOneOf() {
   static const NoDestructor<std::vector<const tool::TypeInfo*>> types{
-      {&tool::TypeId<T>()...}};
+      {&tool::TypeInfo::Get<T>()...}};
   static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
   type_spec_ = MultiType{*types, &*name};
   return *this;
@@ -761,9 +761,11 @@ cc_library(
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:advanced_proto",
+        "//mediapipe/framework/port:file_helpers",
         "//mediapipe/framework/port:gtest",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
@@ -58,14 +58,14 @@ class TypeMap {
  public:
   template <class T>
   bool Has() const {
-    return content_.count(TypeId<T>()) > 0;
+    return content_.count(TypeInfo::Get<T>()) > 0;
   }
   template <class T>
   T* Get() const {
     if (!Has<T>()) {
-      content_[TypeId<T>()] = std::make_shared<T>();
+      content_[TypeInfo::Get<T>()] = std::make_shared<T>();
     }
-    return static_cast<T*>(content_[TypeId<T>()].get());
+    return static_cast<T*>(content_[TypeInfo::Get<T>()].get());
   }
 
  private:
@@ -20,6 +20,7 @@
 #include <memory>
 #include <string>
 
+#include "absl/cleanup/cleanup.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
@@ -33,6 +34,7 @@
 #include "mediapipe/framework/formats/image_format.pb.h"
 #include "mediapipe/framework/port/advanced_proto_inc.h"
+#include "mediapipe/framework/port/file_helpers.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/proto_ns.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -208,6 +210,27 @@ bool CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2,
   return false;
 }
 
+absl::Status CompareAndSaveImageOutput(
+    absl::string_view golden_image_path, const ImageFrame& actual,
+    const ImageFrameComparisonOptions& options) {
+  ASSIGN_OR_RETURN(auto output_img_path, SavePngTestOutput(actual, "output"));
+
+  auto expected = LoadTestImage(GetTestFilePath(golden_image_path));
+  if (!expected.ok()) {
+    return expected.status();
+  }
+  ASSIGN_OR_RETURN(auto expected_img_path,
+                   SavePngTestOutput(**expected, "expected"));
+
+  std::unique_ptr<ImageFrame> diff_img;
+  auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
+                                   options.max_alpha_diff, options.max_avg_diff,
+                                   diff_img);
+  ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
+
+  return status;
+}
+
 std::string GetTestRootDir() {
   return file::JoinPath(std::getenv("TEST_SRCDIR"), "mediapipe");
 }
@@ -275,6 +298,23 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
   return nullptr;
 }
 
+// Write an ImageFrame as PNG to the test undeclared outputs directory.
+// The image's name will contain the given prefix and a timestamp.
+// Returns the path to the output if successful.
+absl::StatusOr<std::string> SavePngTestOutput(
+    const mediapipe::ImageFrame& image, absl::string_view prefix) {
+  std::string now_string = absl::FormatTime(absl::Now());
+  std::string output_relative_path =
+      absl::StrCat(prefix, "_", now_string, ".png");
+  std::string output_full_path =
+      file::JoinPath(GetTestOutputsDir(), output_relative_path);
+  RET_CHECK(stbi_write_png(output_full_path.c_str(), image.Width(),
+                           image.Height(), image.NumberOfChannels(),
+                           image.PixelData(), image.WidthStep()))
+      << " path: " << output_full_path;
+  return output_relative_path;
+}
+
 bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) {
   int fd = open(path.c_str(), O_RDONLY);
   if (fd == -1) {
@@ -22,20 +22,33 @@
 namespace mediapipe {
 using mediapipe::CalculatorGraphConfig;
 
+struct ImageFrameComparisonOptions {
+  // NOTE: these values are not normalized: use a value from 0 to 2^8-1
+  // for 8-bit data and a value from 0 to 2^16-1 for 16-bit data.
+  // Although these members are declared as floats, all uint8/uint16
+  // values are exactly representable. (2^24 + 1 is the first non-representable
+  // positive integral value.)
+
+  // Maximum value difference allowed for non-alpha channels.
+  float max_color_diff;
+  // Maximum value difference allowed for alpha channel (if present).
+  float max_alpha_diff;
+  // Maximum difference for all channels, averaged across all pixels.
+  float max_avg_diff;
+};
+
+// Compares an output image with a golden file. Saves the output and difference
+// to the undeclared test outputs.
+// Returns ok if they are equal within the tolerances specified in options.
+absl::Status CompareAndSaveImageOutput(
+    absl::string_view golden_image_path, const ImageFrame& actual,
+    const ImageFrameComparisonOptions& options);
+
 // Checks if two image frames are equal within the specified tolerance.
 // image1 and image2 may be of different-but-compatible image formats (e.g.,
 // SRGB and SRGBA); in that case, only the channels available in both are
 // compared.
-// max_color_diff applies to the first 3 channels; i.e., R, G, B for sRGB and
-// sRGBA, and the single gray channel for GRAY8 and GRAY16. It is the maximum
-// pixel color value difference allowed; i.e., a value from 0 to 2^8-1 for 8-bit
-// data and a value from 0 to 2^16-1 for 16-bit data.
-// max_alpha_diff applies to the 4th (alpha) channel only, if present.
-// max_avg_diff applies to all channels, normalized across all pixels.
-//
-// Note: Although max_color_diff and max_alpha_diff are floats, all uint8/uint16
-// values are exactly representable. (2^24 + 1 is the first non-representable
-// positive integral value.)
+// The diff arguments are as in ImageFrameComparisonOptions.
 absl::Status CompareImageFrames(const ImageFrame& image1,
                                 const ImageFrame& image2,
                                 const float max_color_diff,
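A test built on the new struct and helper could look like the sketch below. The include paths, golden-file path, and tolerance values are assumptions for illustration and are not part of this commit:

#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"

TEST(CompareAndSaveImageOutputTest, MatchesGolden) {
  // Frame under test; a real test would render or compute this.
  mediapipe::ImageFrame actual(mediapipe::ImageFormat::SRGBA, 16, 16);

  mediapipe::ImageFrameComparisonOptions options;
  options.max_color_diff = 2.0f;  // raw 8-bit units, per the NOTE above
  options.max_alpha_diff = 0.0f;
  options.max_avg_diff = 1.0f;

  // Writes "output", "expected", and "diff" PNGs to the undeclared test
  // outputs and fails if the frames differ beyond the tolerances.
  MP_EXPECT_OK(mediapipe::CompareAndSaveImageOutput(
      "mediapipe/objc/testdata/golden.png", actual, options));
}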
@@ -77,6 +90,13 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
 std::unique_ptr<ImageFrame> LoadTestPng(
     absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
 
+// Write an ImageFrame as PNG to the test undeclared outputs directory.
+// The image's name will contain the given prefix and a timestamp.
+// If successful, returns the path to the output file relative to the output
+// directory.
+absl::StatusOr<std::string> SavePngTestOutput(
+    const mediapipe::ImageFrame& image, absl::string_view prefix);
+
 // Returns the luminance image of |original_image|.
 // The format of |original_image| must be sRGB or sRGBA.
 std::unique_ptr<ImageFrame> GenerateLuminanceImage(
@@ -78,12 +78,6 @@ class TypeIndex {
   const TypeInfo& info_;
 };
 
-// Returns a unique identifier for type T.
-template <typename T>
-const TypeInfo& TypeId() {
-  return TypeInfo::Get<T>();
-}
-
 // Helper method that returns a hash code of the given type. This allows for
 // typeid testing across multiple binaries, unlike FastTypeId which used a
 // memory location that only works within the same binary. Moreover, we use this
@@ -94,7 +88,7 @@ const TypeInfo& TypeId() {
 // as much as possible.
 template <typename T>
 size_t GetTypeHash() {
-  return TypeId<T>().hash_code();
+  return TypeInfo::Get<T>().hash_code();
 }
 
 }  // namespace tool
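Across the framework hunks the change is mechanical: every call site that went through the tool::TypeId<T>() helper now calls tool::TypeInfo::Get<T>() directly, and the helper itself is deleted (the hunk above that removes const TypeInfo& TypeId()). A small before/after sketch of a typical call site follows; the wrapper function and the include path are illustrative assumptions, not code from this commit:

#include "mediapipe/framework/tool/type_util.h"

// Illustrative helper that needs a hash identifier for T.
template <typename T>
size_t StableTypeHash() {
  // Before this change:
  //   return mediapipe::tool::TypeId<T>().hash_code();
  // After: ask TypeInfo for T directly.
  return mediapipe::tool::TypeInfo::Get<T>().hash_code();
}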
@@ -386,7 +386,7 @@ inline std::string MediaPipeTypeStringOrDemangled(
 
 template <typename T>
 std::string MediaPipeTypeStringOrDemangled() {
-  return MediaPipeTypeStringOrDemangled(tool::TypeId<T>());
+  return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>());
 }
 
 // Returns type hash id of type identified by type_string or NULL if not
@@ -26,22 +26,6 @@
 namespace mediapipe {
 namespace {
 
-// Write an ImageFrame as PNG to the test undeclared outputs directory.
-// The image's name will contain the given prefix and a timestamp.
-// Returns the path to the output if successful.
-std::string SavePngImage(const mediapipe::ImageFrame& image,
-                         absl::string_view prefix) {
-  std::string output_dir = mediapipe::GetTestOutputsDir();
-  std::string now_string = absl::FormatTime(absl::Now());
-  std::string out_file_path =
-      absl::StrCat(output_dir, "/", prefix, "_", now_string, ".png");
-  EXPECT_TRUE(stbi_write_png(out_file_path.c_str(), image.Width(),
-                             image.Height(), image.NumberOfChannels(),
-                             image.PixelData(), image.WidthStep()))
-      << " path: " << out_file_path;
-  return out_file_path;
-}
-
 void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) {
   auto* data = image.MutablePixelData();
   for (int y = 0; y < image.Height(); ++y) {
@@ -143,8 +127,8 @@ TEST_F(GpuBufferTest, GlTextureView) {
   FillImageFrameRGBA(red, 255, 0, 0, 255);
 
   EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
-  SavePngImage(red, "gltv_red_gold");
-  SavePngImage(*view, "gltv_red_view");
+  MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold"));
+  MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view"));
 }
 
 TEST_F(GpuBufferTest, ImageFrame) {
@@ -178,8 +162,8 @@ TEST_F(GpuBufferTest, ImageFrame) {
     FillImageFrameRGBA(red, 255, 0, 0, 255);
 
     EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
-    SavePngImage(red, "if_red_gold");
-    SavePngImage(*view, "if_red_view");
+    MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold"));
+    MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view"));
   }
 }
 
@@ -212,8 +196,8 @@ TEST_F(GpuBufferTest, Overwrite) {
     FillImageFrameRGBA(red, 255, 0, 0, 255);
 
     EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
-    SavePngImage(red, "ow_red_gold");
-    SavePngImage(*view, "ow_red_view");
+    MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold"));
+    MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view"));
   }
 
   {
@@ -246,8 +230,8 @@ TEST_F(GpuBufferTest, Overwrite) {
     FillImageFrameRGBA(green, 0, 255, 0, 255);
 
    EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0));
-    SavePngImage(green, "ow_green_gold");
-    SavePngImage(*view, "ow_green_view");
+    MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold"));
+    MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view"));
   }
 
   {
@@ -256,8 +240,8 @@ TEST_F(GpuBufferTest, Overwrite) {
     FillImageFrameRGBA(blue, 0, 0, 255, 255);
 
     EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0));
-    SavePngImage(blue, "ow_blue_gold");
-    SavePngImage(*view, "ow_blue_view");
+    MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold"));
+    MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view"));
   }
 }
 
@@ -17,6 +17,7 @@ package com.google.mediapipe.framework;
 import com.google.common.base.Preconditions;
 import com.google.common.flogger.FluentLogger;
 import com.google.mediapipe.framework.ProtoUtil.SerializedMessage;
+import com.google.protobuf.Internal;
 import com.google.protobuf.InvalidProtocolBufferException;
 import com.google.protobuf.MessageLite;
 import com.google.protobuf.Parser;
@@ -119,11 +120,20 @@ public final class PacketGetter {
     return nativeGetProtoBytes(packet.getNativeHandle());
   }
 
-  public static <T extends MessageLite> T getProto(final Packet packet, Class<T> clazz)
+  public static <T extends MessageLite> T getProto(final Packet packet, T defaultInstance)
       throws InvalidProtocolBufferException {
     SerializedMessage result = new SerializedMessage();
     nativeGetProto(packet.getNativeHandle(), result);
-    return ProtoUtil.unpack(result, clazz);
+    return ProtoUtil.unpack(result, defaultInstance);
   }
 
+  /**
+   * @deprecated {@link #getProto(Packet, MessageLite)} is safer to use in obfuscated builds.
+   */
+  @Deprecated
+  public static <T extends MessageLite> T getProto(final Packet packet, Class<T> clazz)
+      throws InvalidProtocolBufferException {
+    return getProto(packet, Internal.getDefaultInstance(clazz));
+  }
+
   public static short[] getInt16Vector(final Packet packet) {
@@ -162,6 +172,13 @@ public final class PacketGetter {
     }
   }
 
+  public static <T extends MessageLite> List<T> getProtoVector(
+      final Packet packet, T defaultInstance) {
+    @SuppressWarnings("unchecked")
+    Parser<T> parser = (Parser<T>) defaultInstance.getParserForType();
+    return getProtoVector(packet, parser);
+  }
+
   public static int getImageWidth(final Packet packet) {
     return nativeGetImageWidth(packet.getNativeHandle());
   }
@@ -15,7 +15,6 @@
 package com.google.mediapipe.framework;
 
 import com.google.protobuf.ExtensionRegistryLite;
-import com.google.protobuf.Internal;
 import com.google.protobuf.InvalidProtocolBufferException;
 import com.google.protobuf.MessageLite;
 import java.util.NoSuchElementException;
@@ -52,10 +51,8 @@ public final class ProtoUtil {
   }
 
   /** Deserializes a MessageLite from a SerializedMessage object. */
-  public static <T extends MessageLite> T unpack(
-      SerializedMessage serialized, java.lang.Class<T> clazz)
+  public static <T extends MessageLite> T unpack(SerializedMessage serialized, T defaultInstance)
       throws InvalidProtocolBufferException {
-    T defaultInstance = Internal.getDefaultInstance(clazz);
     String expectedType = ProtoUtil.getTypeName(defaultInstance.getClass());
     if (!serialized.typeName.equals(expectedType)) {
       throw new InvalidProtocolBufferException(