Project import generated by Copybara.

GitOrigin-RevId: b66251317fbebfbb8e1f2ddc64ea5da84bceb7e5
This commit is contained in:
MediaPipe Team 2022-05-06 14:39:20 -07:00 committed by jqtang
parent 7fb37c80e8
commit 4a20e9909d
12 changed files with 164 additions and 89 deletions

View File

@ -22,7 +22,6 @@
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter_builder.h"
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
@ -53,11 +52,9 @@ class InferenceCalculatorGlImpl
private:
absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches();
absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc,
tflite::InterpreterBuilder* interpreter_builder);
absl::Status BindBuffersToTensors();
absl::Status AllocateTensors();
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
@ -140,11 +137,17 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
#endif // MEDIAPIPE_ANDROID
}
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
// for everything.
if (!use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(LoadModel(cc));
}
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: InitInterpreter(cc);
: LoadDelegateAndAllocateTensors(cc);
}));
return absl::OkStatus();
}
@ -289,6 +292,9 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
// Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = allow_precision_loss_
@ -326,12 +332,17 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
break;
}
}
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
if (kSideInOpResolver(cc).IsConnected()) {
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
} else {
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
}
// Create and bind OpenGL buffers for outputs.
// The buffers are created once and their ids are passed to calculator outputs
@ -350,27 +361,35 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
if (kSideInOpResolver(cc).IsConnected()) {
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
} else {
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
}
RET_CHECK(interpreter_);
#if defined(__EMSCRIPTEN__)
interpreter_builder.SetNumThreads(1);
interpreter_->SetNumThreads(1);
#else
interpreter_builder.SetNumThreads(
interpreter_->SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
MP_RETURN_IF_ERROR(BindBuffersToTensors());
MP_RETURN_IF_ERROR(AllocateTensors());
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
RET_CHECK_NE(
@ -379,8 +398,7 @@ absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::LoadDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
// Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed =
@ -391,11 +409,7 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(
options.compile_options.inline_parameters = 1;
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
&TfLiteGpuDelegateDelete);
interpreter_builder->AddDelegate(delegate_.get());
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
// Get input image sizes.
const auto& input_indices = interpreter_->inputs();
for (int i = 0; i < input_indices.size(); ++i) {
@ -427,6 +441,11 @@ absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
output_indices[i]),
kTfLiteOk);
}
// Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
return absl::OkStatus();
}

View File

@ -180,7 +180,7 @@ class Packet {
// Returns an error if the packet does not contain data of type T.
template <typename T>
absl::Status ValidateAsType() const {
return ValidateAsType(tool::TypeId<T>());
return ValidateAsType(tool::TypeInfo::Get<T>());
}
// Returns an error if the packet is not an instance of
@ -428,7 +428,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
ConvertToVectorOfProtoMessageLitePtrs(const T* data,
/*is_proto_vector=*/std::false_type) {
return absl::InvalidArgumentError(absl::StrCat(
"The Packet stores \"", tool::TypeId<T>().name(), "\"",
"The Packet stores \"", tool::TypeInfo::Get<T>().name(), "\"",
"which is not convertible to vector<proto_ns::MessageLite*>."));
}
@ -510,7 +510,9 @@ class Holder : public HolderBase {
HolderSupport<T>::EnsureStaticInit();
return *ptr_;
}
const tool::TypeInfo& GetTypeInfo() const final { return tool::TypeId<T>(); }
const tool::TypeInfo& GetTypeInfo() const final {
return tool::TypeInfo::Get<T>();
}
// Releases the underlying data pointer and transfers the ownership to a
// unique pointer.
// This method is dangerous and is only used by Packet::Consume() if the

View File

@ -259,14 +259,14 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set);
template <typename T>
PacketType& PacketType::Set() {
type_spec_ = &tool::TypeId<T>();
type_spec_ = &tool::TypeInfo::Get<T>();
return *this;
}
template <typename... T>
PacketType& PacketType::SetOneOf() {
static const NoDestructor<std::vector<const tool::TypeInfo*>> types{
{&tool::TypeId<T>()...}};
{&tool::TypeInfo::Get<T>()...}};
static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
type_spec_ = MultiType{*types, &*name};
return *this;

View File

@ -761,9 +761,11 @@ cc_library(
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",

View File

@ -58,14 +58,14 @@ class TypeMap {
public:
template <class T>
bool Has() const {
return content_.count(TypeId<T>()) > 0;
return content_.count(TypeInfo::Get<T>()) > 0;
}
template <class T>
T* Get() const {
if (!Has<T>()) {
content_[TypeId<T>()] = std::make_shared<T>();
content_[TypeInfo::Get<T>()] = std::make_shared<T>();
}
return static_cast<T*>(content_[TypeId<T>()].get());
return static_cast<T*>(content_[TypeInfo::Get<T>()].get());
}
private:

View File

@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
@ -33,6 +34,7 @@
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
@ -208,6 +210,27 @@ bool CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2,
return false;
}
// Compares `actual` against the golden image stored at `golden_image_path`.
// The actual image, the golden image and (when one was produced) the
// difference image are written as PNG test outputs via SavePngTestOutput.
//
// Returns the status of the image comparison. An error is returned earlier if
// the golden image cannot be loaded or if saving one of the images fails.
absl::Status CompareAndSaveImageOutput(
    absl::string_view golden_image_path, const ImageFrame& actual,
    const ImageFrameComparisonOptions& options) {
  // Save the actual output first so it is available even if loading the golden
  // image fails below.
  ASSIGN_OR_RETURN(auto output_img_path, SavePngTestOutput(actual, "output"));
  auto expected = LoadTestImage(GetTestFilePath(golden_image_path));
  if (!expected.ok()) {
    return expected.status();
  }
  ASSIGN_OR_RETURN(auto expected_img_path,
                   SavePngTestOutput(**expected, "expected"));
  std::unique_ptr<ImageFrame> diff_img;
  auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
                                   options.max_alpha_diff, options.max_avg_diff,
                                   diff_img);
  // diff_img starts out null; only dereference it if the comparison actually
  // filled it in.
  if (diff_img) {
    ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
  }
  return status;
}
// Returns the "mediapipe" directory joined onto the value of the TEST_SRCDIR
// environment variable.
std::string GetTestRootDir() {
  // std::getenv returns a null pointer when the variable is not set; never
  // forward that null pointer to the path helper.
  const char* test_srcdir = std::getenv("TEST_SRCDIR");
  return file::JoinPath(test_srcdir != nullptr ? test_srcdir : "", "mediapipe");
}
@ -275,6 +298,23 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
return nullptr;
}
// Saves `image` as a PNG file inside the directory returned by
// GetTestOutputsDir() (the test undeclared outputs directory).
// The file is named "<prefix>_<formatted current time>.png".
// On success, returns that file name, i.e. the path relative to the outputs
// directory; fails if the PNG writer reports an error.
absl::StatusOr<std::string> SavePngTestOutput(
    const mediapipe::ImageFrame& image, absl::string_view prefix) {
  // The timestamp keeps successive outputs with the same prefix apart.
  std::string file_name =
      absl::StrCat(prefix, "_", absl::FormatTime(absl::Now()), ".png");
  const std::string output_full_path =
      file::JoinPath(GetTestOutputsDir(), file_name);
  RET_CHECK(stbi_write_png(output_full_path.c_str(), image.Width(),
                           image.Height(), image.NumberOfChannels(),
                           image.PixelData(), image.WidthStep()))
      << " path: " << output_full_path;
  return file_name;
}
bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) {
int fd = open(path.c_str(), O_RDONLY);
if (fd == -1) {

View File

@ -22,20 +22,33 @@
namespace mediapipe {
using mediapipe::CalculatorGraphConfig;
// Tolerances used when comparing two image frames.
struct ImageFrameComparisonOptions {
  // NOTE: these values are not normalized: use a value from 0 to 2^8-1
  // for 8-bit data and a value from 0 to 2^16-1 for 16-bit data.
  // Although these members are declared as floats, all uint8/uint16
  // values are exactly representable. (2^24 + 1 is the first non-representable
  // positive integral value.)
  //
  // Every tolerance defaults to 0 (no difference allowed) so that a
  // default-constructed instance never carries indeterminate values.

  // Maximum value difference allowed for non-alpha channels.
  float max_color_diff = 0.0f;
  // Maximum value difference allowed for alpha channel (if present).
  float max_alpha_diff = 0.0f;
  // Maximum difference for all channels, averaged across all pixels.
  float max_avg_diff = 0.0f;
};
// Compares an output image with a golden file. Saves the output and difference
// to the undeclared test outputs.
// Returns ok if they are equal within the tolerances specified in options.
absl::Status CompareAndSaveImageOutput(
absl::string_view golden_image_path, const ImageFrame& actual,
const ImageFrameComparisonOptions& options);
// Checks if two image frames are equal within the specified tolerance.
// image1 and image2 may be of different-but-compatible image formats (e.g.,
// SRGB and SRGBA); in that case, only the channels available in both are
// compared.
// max_color_diff applies to the first 3 channels; i.e., R, G, B for sRGB and
// sRGBA, and the single gray channel for GRAY8 and GRAY16. It is the maximum
// pixel color value difference allowed; i.e., a value from 0 to 2^8-1 for 8-bit
// data and a value from 0 to 2^16-1 for 16-bit data.
// max_alpha_diff applies to the 4th (alpha) channel only, if present.
// max_avg_diff applies to all channels, normalized across all pixels.
//
// Note: Although max_color_diff and max_alpha_diff are floats, all uint8/uint16
// values are exactly representable. (2^24 + 1 is the first non-representable
// positive integral value.)
// The diff arguments are as in ImageFrameComparisonOptions.
absl::Status CompareImageFrames(const ImageFrame& image1,
const ImageFrame& image2,
const float max_color_diff,
@ -77,6 +90,13 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
std::unique_ptr<ImageFrame> LoadTestPng(
absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
// Write an ImageFrame as PNG to the test undeclared outputs directory.
// The image's name will contain the given prefix and a timestamp.
// If successful, returns the path to the output file relative to the output
// directory.
absl::StatusOr<std::string> SavePngTestOutput(
const mediapipe::ImageFrame& image, absl::string_view prefix);
// Returns the luminance image of |original_image|.
// The format of |original_image| must be sRGB or sRGBA.
std::unique_ptr<ImageFrame> GenerateLuminanceImage(

View File

@ -78,12 +78,6 @@ class TypeIndex {
const TypeInfo& info_;
};
// Returns a unique identifier for type T.
template <typename T>
const TypeInfo& TypeId() {
return TypeInfo::Get<T>();
}
// Helper method that returns a hash code of the given type. This allows for
// typeid testing across multiple binaries, unlike FastTypeId which used a
// memory location that only works within the same binary. Moreover, we use this
@ -94,7 +88,7 @@ const TypeInfo& TypeId() {
// as much as possible.
template <typename T>
size_t GetTypeHash() {
return TypeId<T>().hash_code();
return TypeInfo::Get<T>().hash_code();
}
} // namespace tool

View File

@ -386,7 +386,7 @@ inline std::string MediaPipeTypeStringOrDemangled(
template <typename T>
std::string MediaPipeTypeStringOrDemangled() {
return MediaPipeTypeStringOrDemangled(tool::TypeId<T>());
return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>());
}
// Returns type hash id of type identified by type_string or NULL if not

View File

@ -26,22 +26,6 @@
namespace mediapipe {
namespace {
// Write an ImageFrame as PNG to the test undeclared outputs directory.
// The image's name will contain the given prefix and a timestamp.
// Returns the path to the output if successful.
std::string SavePngImage(const mediapipe::ImageFrame& image,
absl::string_view prefix) {
std::string output_dir = mediapipe::GetTestOutputsDir();
std::string now_string = absl::FormatTime(absl::Now());
std::string out_file_path =
absl::StrCat(output_dir, "/", prefix, "_", now_string, ".png");
EXPECT_TRUE(stbi_write_png(out_file_path.c_str(), image.Width(),
image.Height(), image.NumberOfChannels(),
image.PixelData(), image.WidthStep()))
<< " path: " << out_file_path;
return out_file_path;
}
void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) {
auto* data = image.MutablePixelData();
for (int y = 0; y < image.Height(); ++y) {
@ -143,8 +127,8 @@ TEST_F(GpuBufferTest, GlTextureView) {
FillImageFrameRGBA(red, 255, 0, 0, 255);
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
SavePngImage(red, "gltv_red_gold");
SavePngImage(*view, "gltv_red_view");
MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold"));
MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view"));
}
TEST_F(GpuBufferTest, ImageFrame) {
@ -178,8 +162,8 @@ TEST_F(GpuBufferTest, ImageFrame) {
FillImageFrameRGBA(red, 255, 0, 0, 255);
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
SavePngImage(red, "if_red_gold");
SavePngImage(*view, "if_red_view");
MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold"));
MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view"));
}
}
@ -212,8 +196,8 @@ TEST_F(GpuBufferTest, Overwrite) {
FillImageFrameRGBA(red, 255, 0, 0, 255);
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0));
SavePngImage(red, "ow_red_gold");
SavePngImage(*view, "ow_red_view");
MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold"));
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view"));
}
{
@ -246,8 +230,8 @@ TEST_F(GpuBufferTest, Overwrite) {
FillImageFrameRGBA(green, 0, 255, 0, 255);
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0));
SavePngImage(green, "ow_green_gold");
SavePngImage(*view, "ow_green_view");
MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold"));
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view"));
}
{
@ -256,8 +240,8 @@ TEST_F(GpuBufferTest, Overwrite) {
FillImageFrameRGBA(blue, 0, 0, 255, 255);
EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0));
SavePngImage(blue, "ow_blue_gold");
SavePngImage(*view, "ow_blue_view");
MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold"));
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view"));
}
}

View File

@ -17,6 +17,7 @@ package com.google.mediapipe.framework;
import com.google.common.base.Preconditions;
import com.google.common.flogger.FluentLogger;
import com.google.mediapipe.framework.ProtoUtil.SerializedMessage;
import com.google.protobuf.Internal;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
@ -119,11 +120,20 @@ public final class PacketGetter {
return nativeGetProtoBytes(packet.getNativeHandle());
}
public static <T extends MessageLite> T getProto(final Packet packet, Class<T> clazz)
public static <T extends MessageLite> T getProto(final Packet packet, T defaultInstance)
throws InvalidProtocolBufferException {
SerializedMessage result = new SerializedMessage();
nativeGetProto(packet.getNativeHandle(), result);
return ProtoUtil.unpack(result, clazz);
return ProtoUtil.unpack(result, defaultInstance);
}
/**
* Convenience overload that looks up the default instance of {@code clazz} with
* {@link Internal#getDefaultInstance} and forwards to {@link #getProto(Packet, MessageLite)}.
*
* @param packet the packet to read the proto from
* @param clazz the message class whose default instance is used for the lookup
* @throws InvalidProtocolBufferException propagated from {@link #getProto(Packet, MessageLite)}
* @deprecated {@link #getProto(Packet, MessageLite)} is safer to use in obfuscated builds.
*/
@Deprecated
public static <T extends MessageLite> T getProto(final Packet packet, Class<T> clazz)
throws InvalidProtocolBufferException {
return getProto(packet, Internal.getDefaultInstance(clazz));
}
public static short[] getInt16Vector(final Packet packet) {
@ -162,6 +172,13 @@ public final class PacketGetter {
}
}
/**
* Overload of {@code getProtoVector} that takes a default instance of the message type instead
* of a parser. The parser is obtained from {@code defaultInstance.getParserForType()} and the
* call is forwarded to the parser-based overload.
*
* @param packet the packet to read from
* @param defaultInstance an instance of the message type {@code T}, used only to get its parser
*/
public static <T extends MessageLite> List<T> getProtoVector(
final Packet packet, T defaultInstance) {
// getParserForType() is not typed on T, hence the unchecked cast to Parser<T>.
@SuppressWarnings("unchecked")
Parser<T> parser = (Parser<T>) defaultInstance.getParserForType();
return getProtoVector(packet, parser);
}
public static int getImageWidth(final Packet packet) {
return nativeGetImageWidth(packet.getNativeHandle());
}

View File

@ -15,7 +15,6 @@
package com.google.mediapipe.framework;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.Internal;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import java.util.NoSuchElementException;
@ -52,10 +51,8 @@ public final class ProtoUtil {
}
/** Deserializes a MessageLite from a SerializedMessage object. */
public static <T extends MessageLite> T unpack(
SerializedMessage serialized, java.lang.Class<T> clazz)
public static <T extends MessageLite> T unpack(SerializedMessage serialized, T defaultInstance)
throws InvalidProtocolBufferException {
T defaultInstance = Internal.getDefaultInstance(clazz);
String expectedType = ProtoUtil.getTypeName(defaultInstance.getClass());
if (!serialized.typeName.equals(expectedType)) {
throw new InvalidProtocolBufferException(