Pulled changes from master

Commit 164eae8c16 by Prianka Liz Kariat, 2023-05-24 19:15:48 +05:30
78 changed files with 644 additions and 337 deletions

View File

@@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
 typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
 
+typedef ConcatenateVectorCalculator<std::string>
+    ConcatenateStringVectorCalculator;
+MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator);
+
 // Example config:
 // node {
 //   calculator: "ConcatenateTfLiteTensorVectorCalculator"
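As additional context beyond the diff, here is a minimal, hypothetical sketch (not part of the commit) of a graph config that exercises the newly registered ConcatenateStringVectorCalculator, following the "// Example config:" convention used in this file. The stream names are invented, and ParseTextProtoOrDie comes from mediapipe/framework/port/parse_text_proto.h.

#include <iostream>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

int main() {
  // Hypothetical config: two string-vector input streams concatenated into one.
  mediapipe::CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        input_stream: "strings_a"
        input_stream: "strings_b"
        output_stream: "concatenated_strings"
        node {
          calculator: "ConcatenateStringVectorCalculator"
          input_stream: "strings_a"
          input_stream: "strings_b"
          output_stream: "concatenated_strings"
        }
      )pb");
  std::cout << "Parsed " << config.node_size() << " node(s)." << std::endl;
  return 0;
}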

View File

@@ -30,13 +30,15 @@ namespace mediapipe {
 typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
 
-void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp,
+template <typename T>
+void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp,
                     CalculatorRunner* runner) {
   runner->MutableInputs()->Index(index).packets.push_back(
-      MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
+      MakePacket<std::vector<T>>(input).At(Timestamp(timestamp)));
 }
 
-void AddInputVectors(const std::vector<std::vector<int>>& inputs,
+template <typename T>
+void AddInputVectors(const std::vector<std::vector<T>>& inputs,
                      int64_t timestamp, CalculatorRunner* runner) {
   for (int i = 0; i < inputs.size(); ++i) {
     AddInputVector(i, inputs[i], timestamp, runner);
@@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
   EXPECT_EQ(0, outputs.size());
 }
 
+TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) {
+  CalculatorRunner runner("ConcatenateStringVectorCalculator",
+                          /*options_string=*/"", /*num_inputs=*/3,
+                          /*num_outputs=*/1, /*num_side_packets=*/0);
+
+  std::vector<std::vector<std::string>> inputs = {
+      {"a", "b"}, {"c"}, {"d", "e", "f"}};
+  AddInputVectors(inputs, /*timestamp=*/1, &runner);
+  MP_ASSERT_OK(runner.Run());
+
+  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
+  std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
+  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
+}
+
 typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
     TestConcatenateUniqueIntPtrCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);

View File

@@ -1099,6 +1099,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
     ],
+    alwayslink = True,  # Defines TestServiceCalculator
 )
 
 cc_library(

View File

@@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
   return std::move(SetNoLogging());
 }
 
-StatusBuilder::operator Status() const& {
+StatusBuilder::operator absl::Status() const& {
   return StatusBuilder(*this).JoinMessageToStatus();
 }
 
-StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
+StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
 
 absl::Status StatusBuilder::JoinMessageToStatus() {
   if (!impl_) {

View File

@@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
     return std::move(*this << msg);
   }
 
-  operator Status() const&;
-  operator Status() &&;
+  operator absl::Status() const&;
+  operator absl::Status() &&;
 
   absl::Status JoinMessageToStatus();
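The conversion operators above now name absl::Status explicitly instead of the unqualified Status alias. Below is a hedged sketch (not from the commit) of what the conversion looks like at a call site, assuming the StatusBuilder(original_status, location) constructor and the MEDIAPIPE_LOC macro provided by the surrounding headers.

#include "absl/status/status.h"
#include "mediapipe/framework/deps/status_builder.h"

// Sketch only: a function returning absl::Status can build its error with
// StatusBuilder and rely on the operator absl::Status() conversions above.
absl::Status ValidatePositive(int value) {
  if (value <= 0) {
    return mediapipe::StatusBuilder(absl::InvalidArgumentError("bad value"),
                                    MEDIAPIPE_LOC)
           << " got value=" << value;
  }
  return absl::OkStatus();
}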

View File

@@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
     lhs op## = rhs;  \
     return lhs;      \
   }
-STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
+STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
 #undef STRONG_INT_VS_STRONG_INT_BINARY_OP
 
 // Define operators that take one StrongInt and one native integer argument.
@@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
     rhs op## = lhs;  \
     return rhs;      \
   }
-STRONG_INT_VS_NUMERIC_BINARY_OP(*);
-NUMERIC_VS_STRONG_INT_BINARY_OP(*);
-STRONG_INT_VS_NUMERIC_BINARY_OP(/);
-STRONG_INT_VS_NUMERIC_BINARY_OP(%);
-STRONG_INT_VS_NUMERIC_BINARY_OP(<<);  // NOLINT(whitespace/operators)
-STRONG_INT_VS_NUMERIC_BINARY_OP(>>);  // NOLINT(whitespace/operators)
+STRONG_INT_VS_NUMERIC_BINARY_OP(*)
+NUMERIC_VS_STRONG_INT_BINARY_OP(*)
+STRONG_INT_VS_NUMERIC_BINARY_OP(/)
+STRONG_INT_VS_NUMERIC_BINARY_OP(%)
+STRONG_INT_VS_NUMERIC_BINARY_OP(<<)  // NOLINT(whitespace/operators)
+STRONG_INT_VS_NUMERIC_BINARY_OP(>>)  // NOLINT(whitespace/operators)
 #undef STRONG_INT_VS_NUMERIC_BINARY_OP
 #undef NUMERIC_VS_STRONG_INT_BINARY_OP
@@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>);  // NOLINT(whitespace/operators)
     StrongInt<TagType, ValueType, ValidatorType> rhs) {  \
   return lhs.value() op rhs.value();                     \
 }
-STRONG_INT_COMPARISON_OP(==);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(!=);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(<);   // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(<=);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(>);   // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(>=);  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(==)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(!=)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(<)   // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(<=)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(>)   // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(>=)  // NOLINT(whitespace/operators)
 #undef STRONG_INT_COMPARISON_OP
 
 }  // namespace intops

View File

@@ -44,7 +44,6 @@ class GraphServiceBase {
   constexpr GraphServiceBase(const char* key) : key(key) {}
 
-  virtual ~GraphServiceBase() = default;
   inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
     return DefaultInitializationUnsupported();
   }
@@ -52,14 +51,32 @@ class GraphServiceBase {
   const char* key;
 
  protected:
+  // `GraphService<T>` objects, deriving `GraphServiceBase` are designed to be
+  // global constants and not ever deleted through `GraphServiceBase`. Hence,
+  // protected and non-virtual destructor which helps to make `GraphService<T>`
+  // trivially destructible and properly defined as global constants.
+  //
+  // A class with any virtual functions should have a destructor that is either
+  // public and virtual or else protected and non-virtual.
+  // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual
+  ~GraphServiceBase() = default;
+
   absl::Status DefaultInitializationUnsupported() const {
     return absl::UnimplementedError(absl::StrCat(
         "Graph service '", key, "' does not support default initialization"));
   }
 };
 
+// A global constant to refer a service:
+// - Requesting `CalculatorContract::UseService` from calculator
+// - Accessing `Calculator/SubgraphContext::Service`from calculator/subgraph
+// - Setting before graph initialization `CalculatorGraph::SetServiceObject`
+//
+// NOTE: In headers, define your graph service reference safely as following:
+// `inline constexpr GraphService<YourService> kYourService("YourService");`
+//
 template <typename T>
-class GraphService : public GraphServiceBase {
+class GraphService final : public GraphServiceBase {
  public:
   using type = T;
   using packet_type = std::shared_ptr<T>;
@@ -68,7 +85,7 @@ class GraphService : public GraphServiceBase {
           kDisallowDefaultInitialization)
       : GraphServiceBase(my_key), default_init_(default_init) {}
 
-  absl::StatusOr<Packet> CreateDefaultObject() const override {
+  absl::StatusOr<Packet> CreateDefaultObject() const final {
     if (default_init_ != kAllowDefaultInitialization) {
       return DefaultInitializationUnsupported();
     }
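To make the new `inline constexpr` guidance concrete, here is a hedged sketch (not from the commit) of how a service defined that way is typically requested and read inside a calculator, using the `CalculatorContract::UseService` and `CalculatorContext::Service` entry points the new comment refers to. MyData, kMyService, and the calculator name are invented for illustration.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/graph_service.h"

struct MyData {
  int value = 0;
};

// Hypothetical service constant, defined as the NOTE above recommends.
inline constexpr mediapipe::GraphService<MyData> kMyService("MyService");

class UsesMyServiceCalculator : public mediapipe::CalculatorBase {
 public:
  static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
    // Declare the dependency; Optional() lets the graph run without it.
    cc->UseService(kMyService).Optional();
    return absl::OkStatus();
  }
  absl::Status Open(mediapipe::CalculatorContext* cc) override {
    if (cc->Service(kMyService).IsAvailable()) {
      value_ = cc->Service(kMyService).GetObject().value;
    }
    return absl::OkStatus();
  }
  absl::Status Process(mediapipe::CalculatorContext* cc) override {
    return absl::OkStatus();
  }

 private:
  int value_ = 0;
};
REGISTER_CALCULATOR(UsesMyServiceCalculator);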

View File

@@ -7,7 +7,7 @@
 namespace mediapipe {
 namespace {
 
-const GraphService<int> kIntService("mediapipe::IntService");
+constexpr GraphService<int> kIntService("mediapipe::IntService");
 
 }  // namespace
 
 TEST(GraphServiceManager, SetGetServiceObject) {

View File

@@ -14,6 +14,8 @@
 #include "mediapipe/framework/graph_service.h"
 
+#include <type_traits>
+
 #include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/canonical_errors.h"
@@ -159,7 +161,7 @@ TEST_F(GraphServiceTest, CreateDefault) {
 
 struct TestServiceData {};
 
-const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
+constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
     "kTestServiceAllowDefaultInitialization",
     GraphServiceBase::kAllowDefaultInitialization);
@@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest,
                        HasSubstr("Service is unavailable.")));
 }
 
-const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization(
-    "kTestServiceDisallowDefaultInitialization",
-    GraphServiceBase::kDisallowDefaultInitialization);
+constexpr GraphService<TestServiceData>
+    kTestServiceDisallowDefaultInitialization(
+        "kTestServiceDisallowDefaultInitialization",
+        GraphServiceBase::kDisallowDefaultInitialization);
+
+static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>,
+              "GraphService is not trivially destructible");
 
 class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator
     : public CalculatorBase {

View File

@@ -16,15 +16,6 @@
 
 namespace mediapipe {
 
-const GraphService<TestServiceObject> kTestService(
-    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
-const GraphService<int> kAnotherService(
-    "another_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NoDefaultConstructor> kNoDefaultService(
-    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NeedsCreateMethod> kNeedsCreateService(
-    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
-
 absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
   cc->Inputs().Index(0).Set<int>();
   cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));

View File

@@ -22,14 +22,17 @@ namespace mediapipe {
 
 using TestServiceObject = std::map<std::string, int>;
 
-extern const GraphService<TestServiceObject> kTestService;
-extern const GraphService<int> kAnotherService;
+inline constexpr GraphService<TestServiceObject> kTestService(
+    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
+inline constexpr GraphService<int> kAnotherService(
+    "another_service", GraphServiceBase::kAllowDefaultInitialization);
 
 class NoDefaultConstructor {
  public:
  NoDefaultConstructor() = delete;
 };
 
-extern const GraphService<NoDefaultConstructor> kNoDefaultService;
+inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService(
+    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
 
 class NeedsCreateMethod {
  public:
@@ -40,7 +43,8 @@ class NeedsCreateMethod {
  private:
   NeedsCreateMethod() = default;
 };
 
-extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
+inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService(
+    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
 
 // Use a service.
 class TestServiceCalculator : public CalculatorBase {

View File

@@ -57,7 +57,7 @@ namespace mediapipe {
 // have underflow/overflow etc.  This type is used internally by Timestamp
 // and TimestampDiff.
 MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
-                               mediapipe::intops::LogFatalOnError);
+                               mediapipe::intops::LogFatalOnError)
 
 class TimestampDiff;

View File

@@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
 #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
   SET_MEDIAPIPE_TYPE_MAP_VALUE(                                                \
       mediapipe::PacketTypeIdToMediaPipeTypeData,                              \
-      mediapipe::tool::GetTypeHash<                                            \
-          mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(),     \
+      mediapipe::TypeId::Of<                                                   \
+          mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()      \
+          .hash_code(),                                                        \
       (mediapipe::MediaPipeTypeData{                                           \
-          mediapipe::tool::GetTypeHash<                                        \
-              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+          mediapipe::TypeId::Of<                                               \
+              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()  \
+              .hash_code(),                                                    \
           type_name, serialize_fn, deserialize_fn}));                          \
   SET_MEDIAPIPE_TYPE_MAP_VALUE(                                                \
       mediapipe::PacketTypeStringToMediaPipeTypeData, type_name,               \
       (mediapipe::MediaPipeTypeData{                                           \
-          mediapipe::tool::GetTypeHash<                                        \
-              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+          mediapipe::TypeId::Of<                                               \
+              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()  \
+              .hash_code(),                                                    \
           type_name, serialize_fn, deserialize_fn}));
 // End define MEDIAPIPE_REGISTER_TYPE.

View File

@@ -38,7 +38,10 @@ cc_library(
     srcs = ["gpu_service.cc"],
     hdrs = ["gpu_service.h"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:graph_service"] + select({
+    deps = [
+        "//mediapipe/framework:graph_service",
+        "@com_google_absl//absl/base:core_headers",
+    ] + select({
         "//conditions:default": [
             ":gpu_shared_data_internal",
         ],
@@ -292,6 +295,7 @@ cc_library(
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:logging",
         "@com_google_absl//absl/functional:bind_front",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ] + select({
@@ -630,6 +634,7 @@ cc_library(
         "//mediapipe/framework:executor",
         "//mediapipe/framework/deps:no_destructor",
         "//mediapipe/framework/port:ret_check",
+        "@com_google_absl//absl/base:core_headers",
     ] + select({
         "//conditions:default": [],
         "//mediapipe:apple": [

View File

@@ -47,6 +47,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height,
   auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height,
                                                 format, nullptr);
   if (!buf->CreateInternal(data, alignment)) {
+    LOG(WARNING) << "Failed to create a GL texture";
     return nullptr;
   }
   return buf;
@@ -106,7 +107,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width,
 bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
   auto context = GlContext::GetCurrent();
-  if (!context) return false;
+  if (!context) {
+    LOG(WARNING) << "Cannot create a GL texture without a valid context";
+    return false;
+  }
 
   producer_context_ = context;  // Save creation GL context.

View File

@@ -20,6 +20,7 @@
 #include <memory>
 #include <utility>
 
+#include "absl/log/check.h"
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/gpu/gpu_buffer_format.h"
@@ -72,8 +73,10 @@ class GpuBuffer {
   // are not portable. Applications and calculators should normally obtain
   // GpuBuffers in a portable way from the framework, e.g. using
   // GpuBufferMultiPool.
-  explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
-      : holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
+  explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
+    CHECK(storage) << "Cannot construct GpuBuffer with null storage";
+    holder_ = std::make_shared<StorageHolder>(std::move(storage));
+  }
 
 #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   // This is used to support backward-compatible construction of GpuBuffer from

View File

@@ -28,6 +28,12 @@ namespace mediapipe {
 #define GL_HALF_FLOAT 0x140B
 #endif  // GL_HALF_FLOAT
 
+#ifdef __EMSCRIPTEN__
+#ifndef GL_HALF_FLOAT_OES
+#define GL_HALF_FLOAT_OES 0x8D61
+#endif  // GL_HALF_FLOAT_OES
+#endif  // __EMSCRIPTEN__
+
 #if !MEDIAPIPE_DISABLE_GPU
 #ifdef GL_ES_VERSION_2_0
 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
@@ -48,6 +54,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
     case GL_RG8:
       info->gl_internal_format = info->gl_format = GL_RG_EXT;
       return;
+#ifdef __EMSCRIPTEN__
+    case GL_RGBA16F:
+      info->gl_internal_format = GL_RGBA;
+      info->gl_type = GL_HALF_FLOAT_OES;
+      return;
+#endif  // __EMSCRIPTEN__
     default:
       return;
   }

View File

@@ -15,6 +15,7 @@
 #ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_
 #define MEDIAPIPE_GPU_GPU_SERVICE_H_
 
+#include "absl/base/attributes.h"
 #include "mediapipe/framework/graph_service.h"
 
 #if !MEDIAPIPE_DISABLE_GPU
@@ -29,7 +30,7 @@ class GpuResources {
 };
 #endif  // MEDIAPIPE_DISABLE_GPU
 
-extern const GraphService<GpuResources> kGpuService;
+ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
 
 }  // namespace mediapipe

View File

@@ -14,6 +14,7 @@
 #include "mediapipe/gpu/gpu_shared_data_internal.h"
 
+#include "absl/base/attributes.h"
 #include "mediapipe/framework/deps/no_destructor.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/gpu/gl_context.h"
@@ -116,7 +117,7 @@ GpuResources::~GpuResources() {
 #endif  // __APPLE__
 }
 
-extern const GraphService<GpuResources> kGpuService;
+ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
 
 absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
   CHECK(node->Contract().ServiceRequests().contains(kGpuService.key));

View File

@@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
     """Instantiates perceptual loss.
 
     Args:
-      feature_weight: The weight coeffcients of multiple model extracted
+      feature_weight: The weight coefficients of multiple model extracted
         features used for calculating the perceptual loss.
       loss_weight: The weight coefficients between `style_loss` and
         `content_loss`.

View File

@@ -105,7 +105,7 @@ class FaceStylizer(object):
     self._train_model(train_data=train_data, preprocessor=self._preprocessor)
 
   def _create_model(self):
-    """Creates the componenets of face stylizer."""
+    """Creates the components of face stylizer."""
     self._encoder = model_util.load_keras_model(
         constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
     )
@@ -138,7 +138,7 @@ class FaceStylizer(object):
     """
     train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
 
-    # TODO: Support processing mulitple input style images. The
+    # TODO: Support processing multiple input style images. The
     # input style images are expected to have similar style.
     # style_sample represents a tuple of (style_image, style_label).
     style_sample = next(iter(train_dataset))

View File

@@ -103,8 +103,8 @@ class ModelResourcesCache {
 };
 
 // Global service for mediapipe task model resources cache.
-const mediapipe::GraphService<ModelResourcesCache> kModelResourcesCacheService(
-    "mediapipe::tasks::ModelResourcesCacheService");
+inline constexpr mediapipe::GraphService<ModelResourcesCache>
+    kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService");
 
 }  // namespace core
 }  // namespace tasks

View File

@@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
   static constexpr Output<Image>::Multiple kConfidenceMaskOut{
       "CONFIDENCE_MASK"};
   static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
+  static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
+      "QUALITY_SCORES"};
   MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
-                          kConfidenceMaskOut, kCategoryMaskOut);
+                          kConfidenceMaskOut, kCategoryMaskOut,
+                          kQualityScoresOut);
 
   static absl::Status UpdateContract(CalculatorContract* cc);
@@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
 
 absl::Status TensorsToSegmentationCalculator::Process(
     mediapipe::CalculatorContext* cc) {
-  RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
-      << "Expect a vector of single Tensor.";
-  const auto& input_tensor = kTensorsIn(cc).Get()[0];
+  const auto& input_tensors = kTensorsIn(cc).Get();
+  if (input_tensors.size() != 1 && input_tensors.size() != 2) {
+    return absl::InvalidArgumentError(
+        "Expect input tensor vector of size 1 or 2.");
+  }
+  const auto& input_tensor = *input_tensors.rbegin();
   ASSIGN_OR_RETURN(const Shape input_shape,
                    GetImageLikeTensorShape(input_tensor));
 
+  // TODO: should use tensor signature to get the correct output
+  // tensor.
+  if (input_tensors.size() == 2) {
+    const auto& quality_tensor = input_tensors[0];
+    const float* quality_score_buffer =
+        quality_tensor.GetCpuReadView().buffer<float>();
+    const std::vector<float> quality_scores(
+        quality_score_buffer,
+        quality_score_buffer +
+            (quality_tensor.bytes() / quality_tensor.element_size()));
+    kQualityScoresOut(cc).Send(quality_scores);
+  } else {
+    // If the input_tensors don't contain quality scores, send the default
+    // quality scores as 1.
+    const std::vector<float> quality_scores(input_shape.channels, 1.0f);
+    kQualityScoresOut(cc).Send(quality_scores);
+  }
+
   // Category mask does not require activation function.
   if (options_.segmenter_options().output_type() ==
           SegmenterOptions::CONFIDENCE_MASK &&
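A self-contained sketch of the quality-score behaviour the new Process() branch implements: when a second (quality) tensor is present its values are forwarded, otherwise one default score of 1.0 per mask channel is emitted. The helper below is illustrative only and does not use the MediaPipe Tensor API.

#include <iostream>
#include <optional>
#include <vector>

// Illustrative stand-in for the calculator's branch: pass quality scores
// through when the model provides them, otherwise default each channel to 1.0.
std::vector<float> MakeQualityScores(
    const std::optional<std::vector<float>>& quality_tensor_values,
    int num_channels) {
  if (quality_tensor_values.has_value()) {
    return *quality_tensor_values;
  }
  return std::vector<float>(num_channels, 1.0f);
}

int main() {
  // Model without a quality head: every category defaults to 1.0.
  for (float score : MakeQualityScores(std::nullopt, 3)) {
    std::cout << score << " ";
  }
  std::cout << "\n";

  // Model with a quality head: scores are forwarded unchanged.
  for (float score : MakeQualityScores(std::vector<float>{0.9f, 0.4f}, 2)) {
    std::cout << score << " ";
  }
  std::cout << "\n";
  return 0;
}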

View File

@@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
 constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {
@@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
       category_mask =
           status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
     }
+    const std::vector<float>& quality_scores =
+        status_or_packets.value()[kQualityScoresStreamName]
+            .Get<std::vector<float>>();
     Packet image_packet = status_or_packets.value()[kImageOutStreamName];
     result_callback(
-        {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
+        {{confidence_masks, category_mask, quality_scores}},
+        image_packet.Get<Image>(),
         image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
   };
 }
@@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
@@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::Status ImageSegmenter::SegmentAsync(

View File

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <type_traits>
@@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 
 // Struct holding the different output streams produced by the image segmenter
@@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
   std::optional<std::vector<Source<Image>>> confidence_masks;
   std::optional<Source<Image>> category_mask;
   // The same as the input image, mainly used for live stream mode.
+  std::optional<Source<std::vector<float>>> quality_scores;
   Source<Image> image;
 };
@@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
         "Segmentation tflite models are assumed to have a single subgraph.",
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
-  const auto* primary_subgraph = (*model.subgraphs())[0];
-  if (primary_subgraph->outputs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Segmentation tflite models are assumed to have a single output.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
   ASSIGN_OR_RETURN(
       *options->mutable_label_items(),
-      GetLabelItemsIfAny(*metadata_extractor,
-                         *metadata_extractor->GetOutputTensorMetadata()->Get(0),
-                         segmenter_option.display_names_locale()));
+      GetLabelItemsIfAny(
+          *metadata_extractor,
+          **metadata_extractor->GetOutputTensorMetadata()->crbegin(),
+          segmenter_option.display_names_locale()));
   return absl::OkStatus();
 }
@@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
   const tflite::Model& model = *model_resources.GetTfLiteModel();
   const auto* primary_subgraph = (*model.subgraphs())[0];
   const auto* output_tensor =
-      (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
+      (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
   return output_tensor;
 }
 
+uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  return primary_subgraph->outputs()->size();
+}
+
 // Get the input tensor from the tflite model of given model resources.
 absl::StatusOr<const tflite::Tensor*> GetInputTensor(
     const core::ModelResources& model_resources) {
@@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
         *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    if (output_streams.quality_scores) {
+      *output_streams.quality_scores >>
+          graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
+    }
     output_streams.image >> graph[Output<Image>(kImageTag)];
     return graph.GetConfig();
   }
@@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
             tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
+    auto quality_scores =
+        tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
     return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
                                  /*confidence_masks=*/std::nullopt,
                                  /*category_mask=*/std::nullopt,
+                                 /*quality_scores=*/quality_scores,
                                  /*image=*/image_and_tensors.image};
   } else {
     std::optional<std::vector<Source<Image>>> confidence_masks;
@@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
     if (output_category_mask_) {
       category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
     }
+    auto quality_scores =
+        tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
    return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
                                 /*confidence_masks=*/confidence_masks,
                                 /*category_mask=*/category_mask,
+                                /*quality_scores=*/quality_scores,
                                 /*image=*/image_and_tensors.image};
   }
 }

View File

@@ -33,6 +33,10 @@ struct ImageSegmenterResult {
   // A category mask of uint8 image in GRAY8 format where each pixel represents
   // the class which the pixel in the original image was predicted to belong to.
   std::optional<Image> category_mask;
+  // The quality scores of the result masks, in the range of [0, 1]. Defaults to
+  // `1` if the model doesn't output quality scores. Each element corresponds to
+  // the score of the category in the model outputs.
+  std::vector<float> quality_scores;
 };
 
 }  // namespace image_segmenter

View File

@@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kRoiStreamName[] = "roi_in";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
 
 constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
 constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
 constexpr absl::string_view kImageTag{"IMAGE"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
@@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 }  // namespace interactive_segmenter
} // namespace interactive_segmenter } // namespace interactive_segmenter

View File

@@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
 constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
           graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    image_segmenter.Out(kQualityScoresTag) >>
+        graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
     image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
 
     return graph.GetConfig();

View File

@@ -81,7 +81,7 @@ strip_api_include_path_prefix(
         "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
         "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
         "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
-        "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectionResult.h",
+        "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
     ],
 )
@@ -162,7 +162,7 @@ apple_static_xcframework(
         ":MPPImageClassifierResult.h",
         ":MPPObjectDetector.h",
        ":MPPObjectDetectorOptions.h",
-        ":MPPObjectDetectionResult.h",
+        ":MPPObjectDetectorResult.h",
     ],
     deps = [
         "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",

View File

@@ -16,17 +16,6 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-/**
- * MediaPipe Tasks delegate.
- */
-typedef NS_ENUM(NSUInteger, MPPDelegate) {
-  /** CPU. */
-  MPPDelegateCPU,
-
-  /** GPU. */
-  MPPDelegateGPU
-} NS_SWIFT_NAME(Delegate);
-
 /**
  * Holds the base options that is used for creation of any type of task. It has fields with
  * important information acceleration configuration, TFLite model source etc.
@@ -37,12 +26,6 @@ NS_SWIFT_NAME(BaseOptions)
 /** The path to the model asset to open and mmap in memory. */
 @property(nonatomic, copy) NSString *modelAssetPath;
 
-/**
- * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
- * delegate CPU is used.
- */
-@property(nonatomic) MPPDelegate delegate;
-
 @end
 
 NS_ASSUME_NONNULL_END

View File

@@ -28,7 +28,6 @@
   MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
 
   baseOptions.modelAssetPath = self.modelAssetPath;
-  baseOptions.delegate = self.delegate;
 
   return baseOptions;
 }

View File

@@ -33,20 +33,6 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
   if (self.modelAssetPath) {
     baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
   }
-
-  switch (self.delegate) {
-    case MPPDelegateCPU: {
-      baseOptionsProto->mutable_acceleration()->mutable_tflite();
-      break;
-    }
-    case MPPDelegateGPU: {
-      // TODO: Provide an implementation for GPU Delegate.
-      [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."];
-      break;
-    }
-    default:
-      break;
-  }
 }
 
 @end

View File

@@ -28,9 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategoryArrays(categories, expectedCategories) \
   XCTAssertEqual(categories.count, expectedCategories.count);     \

View File

@@ -29,9 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
   XCTAssertNotNil(textEmbedderResult);                              \

View File

@@ -34,9 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategoryArrays(categories, expectedCategories) \
   XCTAssertEqual(categories.count, expectedCategories.count);     \
@@ -670,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   // Because of flow limiting, we cannot ensure that the callback will be
   // invoked `iterationCount` times.
-  // An normal expectation will fail if expectation.fullfill() is not called
+  // An normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times.
   // If `expectation.isInverted = true`, the test will only succeed if
-  // expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since in our case we cannot predict how many times the expectation is
   // supposed to be fullfilled setting,
   // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and

View File

@@ -32,9 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \
   XCTAssertEqual(category.index, expectedCategory.index,                                 \
@@ -70,7 +68,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
 #pragma mark Results
 
-+ (MPPObjectDetectionResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
++ (MPPObjectDetectorResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
     (NSInteger)timestampInMilliseconds {
   NSArray<MPPDetection *> *detections = @[
     [[MPPDetection alloc] initWithCategories:@[
@@ -95,8 +93,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
                               keypoints:nil],
   ];
 
-  return [[MPPObjectDetectionResult alloc] initWithDetections:detections
-                               timestampInMilliseconds:timestampInMilliseconds];
+  return [[MPPObjectDetectorResult alloc] initWithDetections:detections
+                              timestampInMilliseconds:timestampInMilliseconds];
 }
 
 - (void)assertDetections:(NSArray<MPPDetection *> *)detections
@@ -112,25 +110,25 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   }
 }
 
-- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult
-            isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult
+- (void)assertObjectDetectorResult:(MPPObjectDetectorResult *)objectDetectorResult
+           isEqualToExpectedResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult
            expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
-  XCTAssertNotNil(objectDetectionResult);
+  XCTAssertNotNil(objectDetectorResult);
 
   NSArray<MPPDetection *> *detectionsSubsetToCompare;
-  XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount);
-  if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) {
-    detectionsSubsetToCompare = [objectDetectionResult.detections
-        subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)];
+  XCTAssertEqual(objectDetectorResult.detections.count, expectedDetectionsCount);
+  if (objectDetectorResult.detections.count > expectedObjectDetectorResult.detections.count) {
+    detectionsSubsetToCompare = [objectDetectorResult.detections
+        subarrayWithRange:NSMakeRange(0, expectedObjectDetectorResult.detections.count)];
   } else {
-    detectionsSubsetToCompare = objectDetectionResult.detections;
+    detectionsSubsetToCompare = objectDetectorResult.detections;
   }
 
   [self assertDetections:detectionsSubsetToCompare
-      isEqualToExpectedDetections:expectedObjectDetectionResult.detections];
+      isEqualToExpectedDetections:expectedObjectDetectorResult.detections];
 
-  XCTAssertEqual(objectDetectionResult.timestampInMilliseconds,
-                 expectedObjectDetectionResult.timestampInMilliseconds);
+  XCTAssertEqual(objectDetectorResult.timestampInMilliseconds,
+                 expectedObjectDetectorResult.timestampInMilliseconds);
 }
 
 #pragma mark File
@@ -195,28 +193,27 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
 - (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
                  usingObjectDetector:(MPPObjectDetector *)objectDetector
                           maxResults:(NSInteger)maxResults
-         equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult {
-  MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage
-                                                                             error:nil];
+          equalsObjectDetectorResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult {
+  MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInImage:mppImage error:nil];
 
-  [self assertObjectDetectionResult:objectDetectionResult
-            isEqualToExpectedResult:expectedObjectDetectionResult
+  [self assertObjectDetectorResult:ObjectDetectorResult
+           isEqualToExpectedResult:expectedObjectDetectorResult
            expectedDetectionsCount:maxResults > 0 ? maxResults
-                                                  : objectDetectionResult.detections.count];
+                                                  : ObjectDetectorResult.detections.count];
 }
 
 - (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
                              usingObjectDetector:(MPPObjectDetector *)objectDetector
                                       maxResults:(NSInteger)maxResults
-                     equalsObjectDetectionResult:
-                         (MPPObjectDetectionResult *)expectedObjectDetectionResult {
+                      equalsObjectDetectorResult:
+                          (MPPObjectDetectorResult *)expectedObjectDetectorResult {
   MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
 
   [self assertResultsOfDetectInImage:mppImage
                  usingObjectDetector:objectDetector
                           maxResults:maxResults
-         equalsObjectDetectionResult:expectedObjectDetectionResult];
+          equalsObjectDetectorResult:expectedObjectDetectorResult];
 }
 
 #pragma mark General Tests
@@ -266,10 +263,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
                              usingObjectDetector:objectDetector
                                       maxResults:-1
-                     equalsObjectDetectionResult:
+                      equalsObjectDetectorResult:
                          [MPPObjectDetectorTests
                              expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
                                  0]];
 }
 
 - (void)testDetectWithOptionsSucceeds {
@@ -280,10 +277,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
                              usingObjectDetector:objectDetector
                                       maxResults:-1
-                     equalsObjectDetectionResult:
+                      equalsObjectDetectorResult:
                          [MPPObjectDetectorTests
                              expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
                                  0]];
 }
 
 - (void)testDetectWithMaxResultsSucceeds {
@@ -297,10 +294,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
                              usingObjectDetector:objectDetector
                                       maxResults:maxResults
-                     equalsObjectDetectionResult:
+                      equalsObjectDetectorResult:
                          [MPPObjectDetectorTests
                              expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
                                  0]];
 }
 
 - (void)testDetectWithScoreThresholdSucceeds {
@@ -316,13 +313,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
                            boundingBox:CGRectMake(608, 161, 381, 439)
                              keypoints:nil],
   ];
-  MPPObjectDetectionResult *expectedObjectDetectionResult =
-      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
+  MPPObjectDetectorResult *expectedObjectDetectorResult =
+      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 
   [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
                              usingObjectDetector:objectDetector
                                       maxResults:-1
-                     equalsObjectDetectionResult:expectedObjectDetectionResult];
+                      equalsObjectDetectorResult:expectedObjectDetectorResult];
 }
 
 - (void)testDetectWithCategoryAllowlistSucceeds {
@@ -359,13 +356,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
                              keypoints:nil],
   ];
 
-  MPPObjectDetectionResult *expectedDetectionResult =
-      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
+  MPPObjectDetectorResult *expectedDetectionResult =
+      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 
   [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
                              usingObjectDetector:objectDetector
                                       maxResults:-1
-                     equalsObjectDetectionResult:expectedDetectionResult];
+                      equalsObjectDetectorResult:expectedDetectionResult];
 }
 
 - (void)testDetectWithCategoryDenylistSucceeds {
@@ -414,13 +411,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
                              keypoints:nil],
   ];
 
-  MPPObjectDetectionResult *expectedDetectionResult =
-      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
+  MPPObjectDetectorResult *expectedDetectionResult =
+      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 
   [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
                              usingObjectDetector:objectDetector
                                       maxResults:-1
-                     equalsObjectDetectionResult:expectedDetectionResult];
+                      equalsObjectDetectorResult:expectedDetectionResult];
 }
 
 - (void)testDetectWithOrientationSucceeds {
@@ -437,8 +434,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
                              keypoints:nil],
   ];
 
-  MPPObjectDetectionResult *expectedDetectionResult =
-      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
+  MPPObjectDetectorResult *expectedDetectionResult =
+      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 
   MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage
                                 orientation:UIImageOrientationRight];
@@ -446,7 +443,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   [self assertResultsOfDetectInImage:image
                  usingObjectDetector:objectDetector
                           maxResults:1
-         equalsObjectDetectionResult:expectedDetectionResult];
+          equalsObjectDetectorResult:expectedDetectionResult];
 }
 
 #pragma mark Running Mode Tests
 
@@ -613,15 +610,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage]; MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage];
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInVideoFrame:image
timestampInMilliseconds:i timestampInMilliseconds:i
error:nil]; error:nil];
[self assertObjectDetectionResult:objectDetectionResult [self assertObjectDetectorResult:ObjectDetectorResult
isEqualToExpectedResult: isEqualToExpectedResult:
[MPPObjectDetectorTests [MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i] expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
expectedDetectionsCount:maxResults]; expectedDetectionsCount:maxResults];
} }
} }
@ -676,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be // Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times. // invoked `iterationCount` times.
// A normal expectation will fail if expectation.fullfill() is not called // A normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. // `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if // If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`. // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is // Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting, // supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if // `expectation.isInverted = true` ensures that test succeeds if
// expectation is fullfilled <= `iterationCount` times. // expectation is fulfilled <= `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc] XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"]; initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1; expectation.expectedFulfillmentCount = iterationCount + 1;
@ -714,16 +711,16 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
#pragma mark MPPObjectDetectorLiveStreamDelegate Methods #pragma mark MPPObjectDetectorLiveStreamDelegate Methods
- (void)objectDetector:(MPPObjectDetector *)objectDetector - (void)objectDetector:(MPPObjectDetector *)objectDetector
didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult didFinishDetectionWithResult:(MPPObjectDetectorResult *)ObjectDetectorResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error { error:(NSError *)error {
NSInteger maxResults = 4; NSInteger maxResults = 4;
[self assertObjectDetectionResult:objectDetectionResult [self assertObjectDetectorResult:ObjectDetectorResult
isEqualToExpectedResult: isEqualToExpectedResult:
[MPPObjectDetectorTests [MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
timestampInMilliseconds] timestampInMilliseconds]
expectedDetectionsCount:maxResults]; expectedDetectionsCount:maxResults];
if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) { if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) {
[outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill]; [outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];

View File

@ -64,4 +64,3 @@ objc_library(
"//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers", "//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
], ],
) )

View File

@ -87,7 +87,7 @@ NS_SWIFT_NAME(GestureRecognizerOptions)
gestureRecognizerLiveStreamDelegate; gestureRecognizerLiveStreamDelegate;
/** Sets the maximum number of hands that can be detected by the GestureRecognizer. */ /** Sets the maximum number of hands that can be detected by the GestureRecognizer. */
@property(nonatomic) NSInteger numberOfHands; @property(nonatomic) NSInteger numberOfHands NS_SWIFT_NAME(numHands);
/** Sets minimum confidence score for the hand detection to be considered successful */ /** Sets minimum confidence score for the hand detection to be considered successful */
@property(nonatomic) float minHandDetectionConfidence; @property(nonatomic) float minHandDetectionConfidence;

View File

@ -31,7 +31,8 @@
MPPGestureRecognizerOptions *gestureRecognizerOptions = [super copyWithZone:zone]; MPPGestureRecognizerOptions *gestureRecognizerOptions = [super copyWithZone:zone];
gestureRecognizerOptions.runningMode = self.runningMode; gestureRecognizerOptions.runningMode = self.runningMode;
gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate = self.gestureRecognizerLiveStreamDelegate; gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate =
self.gestureRecognizerLiveStreamDelegate;
gestureRecognizerOptions.numberOfHands = self.numberOfHands; gestureRecognizerOptions.numberOfHands = self.numberOfHands;
gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence; gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence;
gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence; gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence;

View File

@ -18,9 +18,9 @@
- (instancetype)initWithGestures:(NSArray<NSArray<MPPCategory *> *> *)gestures - (instancetype)initWithGestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
timestampInMilliseconds:(NSInteger)timestampInMilliseconds { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) { if (self) {
_landmarks = landmarks; _landmarks = landmarks;

View File

@ -22,7 +22,12 @@ objc_library(
hdrs = ["sources/MPPGestureRecognizerOptions+Helpers.h"], hdrs = ["sources/MPPGestureRecognizerOptions+Helpers.h"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers", "//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers",
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",

View File

@ -18,7 +18,12 @@
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
namespace { namespace {
using CalculatorOptionsProto = mediapipe::CalculatorOptions; using CalculatorOptionsProto = mediapipe::CalculatorOptions;

View File

@ -17,7 +17,7 @@
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000; static const int kMicrosecondsPerMillisecond = 1000;
namespace { namespace {
using ClassificationResultProto = using ClassificationResultProto =
@ -29,19 +29,26 @@ using ::mediapipe::Packet;
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: + (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const Packet &)packet { (const Packet &)packet {
MPPClassificationResult *classificationResult; // Even if packet does not validate as the expected type, you can safely access the timestamp.
NSInteger timestampInMilliSeconds =
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) { if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
return nil; // MPPClassificationResult's timestamp is populated from `ClassificationResultProto`'s
// timestamp_ms(). It is 0 since the packet can't be validated as a `ClassificationResultProto`.
return [[MPPImageClassifierResult alloc]
initWithClassificationResult:[[MPPClassificationResult alloc] initWithClassifications:@[]
timestampInMilliseconds:0]
timestampInMilliseconds:timestampInMilliSeconds];
} }
classificationResult = [MPPClassificationResult MPPClassificationResult *classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()]; classificationResultWithProto:packet.Get<ClassificationResultProto>()];
return [[MPPImageClassifierResult alloc] return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult initWithClassificationResult:classificationResult
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)]; kMicrosecondsPerMillisecond)];
} }
@end @end

View File

@ -17,9 +17,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
objc_library( objc_library(
name = "MPPObjectDetectionResult", name = "MPPObjectDetectorResult",
srcs = ["sources/MPPObjectDetectionResult.m"], srcs = ["sources/MPPObjectDetectorResult.m"],
hdrs = ["sources/MPPObjectDetectionResult.h"], hdrs = ["sources/MPPObjectDetectorResult.h"],
deps = [ deps = [
"//mediapipe/tasks/ios/components/containers:MPPDetection", "//mediapipe/tasks/ios/components/containers:MPPDetection",
"//mediapipe/tasks/ios/core:MPPTaskResult", "//mediapipe/tasks/ios/core:MPPTaskResult",
@ -31,7 +31,7 @@ objc_library(
srcs = ["sources/MPPObjectDetectorOptions.m"], srcs = ["sources/MPPObjectDetectorOptions.m"],
hdrs = ["sources/MPPObjectDetectorOptions.h"], hdrs = ["sources/MPPObjectDetectorOptions.h"],
deps = [ deps = [
":MPPObjectDetectionResult", ":MPPObjectDetectorResult",
"//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/vision/core:MPPRunningMode", "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
], ],
@ -47,8 +47,8 @@ objc_library(
"-x objective-c++", "-x objective-c++",
], ],
deps = [ deps = [
":MPPObjectDetectionResult",
":MPPObjectDetectorOptions", ":MPPObjectDetectorOptions",
":MPPObjectDetectorResult",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
@ -56,7 +56,7 @@ objc_library(
"//mediapipe/tasks/ios/vision/core:MPPImage", "//mediapipe/tasks/ios/vision/core:MPPImage",
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator", "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner", "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectionResultHelpers",
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers", "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers",
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorResultHelpers",
], ],
) )

View File

@ -15,8 +15,8 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h" #import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@ -109,14 +109,13 @@ NS_SWIFT_NAME(ObjectDetector)
* @param error An optional error parameter populated when there is an error in performing object * @param error An optional error parameter populated when there is an error in performing object
* detection on the input image. * detection on the input image.
* *
* @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
* has a bounding box that is expressed in the unrotated input frame of reference coordinates * has a bounding box that is expressed in the unrotated input frame of reference coordinates
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
* image data. * image data.
*/ */
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image - (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
error:(NSError **)error error:(NSError **)error NS_SWIFT_NAME(detect(image:));
NS_SWIFT_NAME(detect(image:));
/** /**
* Performs object detection on the provided video frame of type `MPPImage` using the whole * Performs object detection on the provided video frame of type `MPPImage` using the whole
@ -139,14 +138,14 @@ NS_SWIFT_NAME(ObjectDetector)
* @param error An optional error parameter populated when there is an error in performing object * @param error An optional error parameter populated when there is an error in performing object
* detection on the input image. * detection on the input image.
* *
* @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
* has a bounding box that is expressed in the unrotated input frame of reference coordinates * has a bounding box that is expressed in the unrotated input frame of reference coordinates
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
* image data. * image data.
*/ */
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:)); NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
/** /**

View File

@ -19,8 +19,8 @@
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h" #import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
namespace { namespace {
using ::mediapipe::NormalizedRect; using ::mediapipe::NormalizedRect;
@ -118,9 +118,9 @@ static NSString *const kTaskName = @"objectDetector";
return; return;
} }
MPPObjectDetectionResult *result = [MPPObjectDetectionResult MPPObjectDetectorResult *result = [MPPObjectDetectorResult
objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName objectDetectorResultWithDetectionsPacket:statusOrPackets
.cppString]]; .value()[kDetectionsStreamName.cppString]];
NSInteger timeStampInMilliseconds = NSInteger timeStampInMilliseconds =
outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() / outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
@ -184,9 +184,9 @@ static NSString *const kTaskName = @"objectDetector";
return inputPacketMap; return inputPacketMap;
} }
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image - (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<NormalizedRect> rect = std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
imageSize:CGSizeMake(image.width, image.height) imageSize:CGSizeMake(image.width, image.height)
@ -213,18 +213,18 @@ static NSString *const kTaskName = @"objectDetector";
return nil; return nil;
} }
return [MPPObjectDetectionResult return [MPPObjectDetectorResult
objectDetectionResultWithDetectionsPacket:outputPacketMap objectDetectorResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]]; .value()[kDetectionsStreamName.cppString]];
} }
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image error:(NSError **)error { - (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
return [self detectInImage:image regionOfInterest:CGRectZero error:error]; return [self detectInImage:image regionOfInterest:CGRectZero error:error];
} }
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error { error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampInMilliseconds:timestampInMilliseconds timestampInMilliseconds:timestampInMilliseconds
error:error]; error:error];
@ -239,9 +239,9 @@ static NSString *const kTaskName = @"objectDetector";
return nil; return nil;
} }
return [MPPObjectDetectionResult return [MPPObjectDetectorResult
objectDetectionResultWithDetectionsPacket:outputPacketMap objectDetectorResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]]; .value()[kDetectionsStreamName.cppString]];
} }
- (BOOL)detectAsyncInImage:(MPPImage *)image - (BOOL)detectAsyncInImage:(MPPImage *)image

View File

@ -16,7 +16,7 @@
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h" #import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@ -44,7 +44,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
* *
* @param objectDetector The object detector which performed the object detection. * @param objectDetector The object detector which performed the object detection.
* This is useful to test equality when there are multiple instances of `MPPObjectDetector`. * This is useful to test equality when there are multiple instances of `MPPObjectDetector`.
* @param result The `MPPObjectDetectionResult` object that contains a list of detections, each * @param result The `MPPObjectDetectorResult` object that contains a list of detections, each
* detection has a bounding box that is expressed in the unrotated input frame of reference * detection has a bounding box that is expressed in the unrotated input frame of reference
* coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the * coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
* underlying image data. * underlying image data.
@ -54,7 +54,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
* detection on the input live stream image data. * detection on the input live stream image data.
*/ */
- (void)objectDetector:(MPPObjectDetector *)objectDetector - (void)objectDetector:(MPPObjectDetector *)objectDetector
didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result didFinishDetectionWithResult:(nullable MPPObjectDetectorResult *)result
timestampInMilliseconds:(NSInteger)timestampInMilliseconds timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(nullable NSError *)error error:(nullable NSError *)error
NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:)); NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:));

View File

@ -19,8 +19,8 @@
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
/** Represents the detection results generated by `MPPObjectDetector`. */ /** Represents the detection results generated by `MPPObjectDetector`. */
NS_SWIFT_NAME(ObjectDetectionResult) NS_SWIFT_NAME(ObjectDetectorResult)
@interface MPPObjectDetectionResult : MPPTaskResult @interface MPPObjectDetectorResult : MPPTaskResult
/** /**
* The array of `MPPDetection` objects each of which has a bounding box that is expressed in the * The array of `MPPDetection` objects each of which has a bounding box that is expressed in the
@ -30,7 +30,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
@property(nonatomic, readonly) NSArray<MPPDetection *> *detections; @property(nonatomic, readonly) NSArray<MPPDetection *> *detections;
/** /**
* Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in * Initializes a new `MPPObjectDetectorResult` with the given array of detections and timestamp (in
* milliseconds). * milliseconds).
* *
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is * @param detections An array of `MPPDetection` objects each of which has a bounding box that is
@ -38,7 +38,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
* x [0,image_height)`, which are the dimensions of the underlying image data. * x [0,image_height)`, which are the dimensions of the underlying image data.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
* *
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections * @return An instance of `MPPObjectDetectorResult` initialized with the given array of detections
* and timestamp (in milliseconds). * and timestamp (in milliseconds).
*/ */
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections

View File

@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h" #import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
@implementation MPPObjectDetectionResult @implementation MPPObjectDetectorResult
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampInMilliseconds:(NSInteger)timestampInMilliseconds { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {

View File

@ -31,12 +31,12 @@ objc_library(
) )
objc_library( objc_library(
name = "MPPObjectDetectionResultHelpers", name = "MPPObjectDetectorResultHelpers",
srcs = ["sources/MPPObjectDetectionResult+Helpers.mm"], srcs = ["sources/MPPObjectDetectorResult+Helpers.mm"],
hdrs = ["sources/MPPObjectDetectionResult+Helpers.h"], hdrs = ["sources/MPPObjectDetectorResult+Helpers.h"],
deps = [ deps = [
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers", "//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectionResult", "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectorResult",
], ],
) )

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h" #import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
@ -20,17 +20,17 @@ NS_ASSUME_NONNULL_BEGIN
static const int kMicroSecondsPerMilliSecond = 1000; static const int kMicroSecondsPerMilliSecond = 1000;
@interface MPPObjectDetectionResult (Helpers) @interface MPPObjectDetectorResult (Helpers)
/** /**
* Creates an `MPPObjectDetectionResult` from a MediaPipe packet containing a * Creates an `MPPObjectDetectorResult` from a MediaPipe packet containing a
* `std::vector<DetectionProto>`. * `std::vector<DetectionProto>`.
* *
* @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`. * @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`.
* *
* @return An `MPPObjectDetectionResult` object that contains a list of detections. * @return An `MPPObjectDetectorResult` object that contains a list of detections.
*/ */
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket: + (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
(const mediapipe::Packet &)packet; (const mediapipe::Packet &)packet;
@end @end

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h" #import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
@ -21,9 +21,9 @@ using DetectionProto = ::mediapipe::Detection;
using ::mediapipe::Packet; using ::mediapipe::Packet;
} // namespace } // namespace
@implementation MPPObjectDetectionResult (Helpers) @implementation MPPObjectDetectorResult (Helpers)
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket: + (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
(const Packet &)packet { (const Packet &)packet {
if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) { if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
return nil; return nil;
@ -37,10 +37,10 @@ using ::mediapipe::Packet;
[detections addObject:[MPPDetection detectionWithProto:detectionProto]]; [detections addObject:[MPPDetection detectionWithProto:detectionProto]];
} }
return [[MPPObjectDetectionResult alloc] return
initWithDetections:detections [[MPPObjectDetectorResult alloc] initWithDetections:detections
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)]; kMicroSecondsPerMilliSecond)];
} }
@end @end

View File

@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
// For 90° and 270° rotations, we need to swap width and height. // For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which: // This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or // - first denormalizes the provided rect by multiplying the rect width or
// height by the image width or height, repectively. // height by the image width or height, respectively.
// - then rotates this denormalized rect by the provided rotation, and // - then rotates this denormalized rect by the provided rotation, and
// uses this for cropping, // uses this for cropping,
// - then finally rotates this back. // - then finally rotates this back.
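A minimal TypeScript sketch of the swap described in the comment above (the helper name and signature are illustrative assumptions, not part of BaseVisionTaskApi):

function roiDenormalizationSize(
    imageWidth: number, imageHeight: number,
    rotationDegrees: number): {width: number, height: number} {
  // ImageToTensorCalculator denormalizes the ROI against the pre-rotation image,
  // so for 90° and 270° rotations the width and height must be swapped.
  if (rotationDegrees % 180 !== 0) {
    return {width: imageHeight, height: imageWidth};
  }
  return {width: imageWidth, height: imageHeight};
}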

View File

@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
segmenterOptions.outputCategoryMask() segmenterOptions.outputCategoryMask()
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask") ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
: -1; : -1;
final int qualityScoresOutStreamIndex =
getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out"); final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.empty(),
new ArrayList<>(),
packets.get(imageOutStreamIndex).getTimestamp()); packets.get(imageOutStreamIndex).getTimestamp());
} }
boolean copyImage = !segmenterOptions.resultListener().isPresent(); boolean copyImage = !segmenterOptions.resultListener().isPresent();
@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
categoryMask = Optional.of(builder.build()); categoryMask = Optional.of(builder.build());
} }
float[] qualityScores =
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
for (float score : qualityScores) {
qualityScoresList.add(score);
}
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
confidenceMasks, confidenceMasks,
categoryMask, categoryMask,
qualityScoresList,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex))); segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
} }
@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
public abstract Builder setOutputCategoryMask(boolean value); public abstract Builder setOutputCategoryMask(boolean value);
/** /**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph * /** Sets an optional {@link ResultListener} to receive the segmentation results when the
* pipeline is done processing an image. * graph pipeline is done processing an image.
*/ */
public abstract Builder setResultListener( public abstract Builder setResultListener(
ResultListener<ImageSegmenterResult, MPImage> value); ResultListener<ImageSegmenterResult, MPImage> value);

View File

@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image * category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to. * was predicted to belong to.
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
* the category in the model outputs.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
*/ */
// TODO: consolidate output formats across platforms. // TODO: consolidate output formats across platforms.
public static ImageSegmenterResult create( public static ImageSegmenterResult create(
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) { Optional<List<MPImage>> confidenceMasks,
Optional<MPImage> categoryMask,
List<Float> qualityScores,
long timestampMs) {
return new AutoValue_ImageSegmenterResult( return new AutoValue_ImageSegmenterResult(
confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs); confidenceMasks.map(Collections::unmodifiableList),
categoryMask,
Collections.unmodifiableList(qualityScores),
timestampMs);
} }
public abstract Optional<List<MPImage>> confidenceMasks(); public abstract Optional<List<MPImage>> confidenceMasks();
public abstract Optional<MPImage> categoryMask(); public abstract Optional<MPImage> categoryMask();
public abstract List<Float> qualityScores();
@Override @Override
public abstract long timestampMs(); public abstract long timestampMs();
} }
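For illustration, a TypeScript sketch of the quality-score convention documented above (a hypothetical helper, not part of the Java API): qualityScores[i] is the score, in [0, 1], of category i, defaulting to 1 when the model emits no scores, so the highest-quality mask can be picked by index.

function bestQualityIndex(qualityScores: number[]): number {
  // Returns the index of the highest-scoring category mask.
  let best = 0;
  for (let i = 1; i < qualityScores.length; i++) {
    if (qualityScores[i] > qualityScores[best]) {
      best = i;
    }
  }
  return best;
}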

View File

@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
outputStreams.add("CATEGORY_MASK:category_mask"); outputStreams.add("CATEGORY_MASK:category_mask");
} }
final int categoryMaskOutStreamIndex = outputStreams.size() - 1; final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("QUALITY_SCORES:quality_scores");
final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("IMAGE:image_out"); outputStreams.add("IMAGE:image_out");
// TODO: add test for stream indices. // TODO: add test for stream indices.
final int imageOutStreamIndex = outputStreams.size() - 1; final int imageOutStreamIndex = outputStreams.size() - 1;
@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.empty(),
new ArrayList<>(),
packets.get(imageOutStreamIndex).getTimestamp()); packets.get(imageOutStreamIndex).getTimestamp());
} }
// If resultListener is not provided, the resulted MPImage is deep copied from // If resultListener is not provided, the resulted MPImage is deep copied from
@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
categoryMask = Optional.of(builder.build()); categoryMask = Optional.of(builder.build());
} }
float[] qualityScores =
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
for (float score : qualityScores) {
qualityScoresList.add(score);
}
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
confidenceMasks, confidenceMasks,
categoryMask, categoryMask,
qualityScoresList,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(imageOutStreamIndex))); RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
} }

View File

@ -201,6 +201,7 @@ py_test(
"//mediapipe/tasks/testdata/vision:test_images", "//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models", "//mediapipe/tasks/testdata/vision:test_models",
], ],
tags = ["not_run:arm"],
deps = [ deps = [
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",

View File

@ -27,13 +27,14 @@ export declare interface Detection {
boundingBox?: BoundingBox; boundingBox?: BoundingBox;
/** /**
* Optional list of keypoints associated with the detection. Keypoints * List of keypoints associated with the detection. Keypoints represent
* represent interesting points related to the detection. For example, the * interesting points related to the detection. For example, the keypoints
* keypoints represent the eye, ear and mouth from face detection model. Or * represent the eye, ear and mouth from face detection model. Or in the
* in the template matching detection, e.g. KNIFT, they can represent the * template matching detection, e.g. KNIFT, they can represent the feature
* feature points for template matching. * points for template matching. Contains an empty list if no keypoints are
* detected.
*/ */
keypoints?: NormalizedKeypoint[]; keypoints: NormalizedKeypoint[];
} }
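A short TypeScript usage sketch of the updated interface (the helper below is illustrative; `categories`, `categoryName`, and `keypoints` are the fields declared above):

function describeDetection(detection: Detection): string {
  // `keypoints` is now always an array; it is empty when the model reports none,
  // so no undefined-check is needed before reading its length.
  const topCategory = detection.categories[0]?.categoryName ?? 'unknown';
  return `${topCategory}: ${detection.keypoints.length} keypoint(s)`;
}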
/** Detection results of a model. */ /** Detection results of a model. */

View File

@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
categoryName: '', categoryName: '',
displayName: '', displayName: '',
}], }],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0} boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
}); });
}); });
}); });

View File

@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
const labels = source.getLabelList(); const labels = source.getLabelList();
const displayNames = source.getDisplayNameList(); const displayNames = source.getDisplayNameList();
const detection: Detection = {categories: []}; const detection: Detection = {categories: [], keypoints: []};
for (let i = 0; i < scores.length; i++) { for (let i = 0; i < scores.length; i++) {
detection.categories.push({ detection.categories.push({
score: scores[i], score: scores[i],
@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
} }
if (source.getLocationData()?.getRelativeKeypointsList().length) { if (source.getLocationData()?.getRelativeKeypointsList().length) {
detection.keypoints = [];
for (const keypoint of for (const keypoint of
source.getLocationData()!.getRelativeKeypointsList()) { source.getLocationData()!.getRelativeKeypointsList()) {
detection.keypoints.push({ detection.keypoints.push({

View File

@ -62,7 +62,10 @@ jasmine_node_test(
mediapipe_ts_library( mediapipe_ts_library(
name = "mask", name = "mask",
srcs = ["mask.ts"], srcs = ["mask.ts"],
deps = [":image"], deps = [
":image",
"//mediapipe/web/graph_runner:platform_utils",
],
) )
mediapipe_ts_library( mediapipe_ts_library(

View File

@ -60,6 +60,10 @@ class MPImageTestContext {
this.webGLTexture = gl.createTexture()!; this.webGLTexture = gl.createTexture()!;
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture); gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap); gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap);
gl.bindTexture(gl.TEXTURE_2D, null); gl.bindTexture(gl.TEXTURE_2D, null);

View File

@ -187,10 +187,11 @@ export class MPImage {
destinationContainer = destinationContainer =
assertNotNull(gl.createTexture(), 'Failed to create texture'); assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, destinationContainer); gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
this.configureTextureParams();
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA, gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA,
gl.UNSIGNED_BYTE, null); gl.UNSIGNED_BYTE, null);
gl.bindTexture(gl.TEXTURE_2D, null);
shaderContext.bindFramebuffer(gl, destinationContainer); shaderContext.bindFramebuffer(gl, destinationContainer);
shaderContext.run(gl, /* flipVertically= */ false, () => { shaderContext.run(gl, /* flipVertically= */ false, () => {
@ -302,6 +303,20 @@ export class MPImage {
return webGLTexture; return webGLTexture;
} }
/** Sets texture params for the currently bound texture. */
private configureTextureParams() {
const gl = this.getGL();
// `gl.LINEAR` might break rendering for some textures, but it allows us to
// do smooth resizing. Ideally, this would be user-configurable, but for now
// we hard-code the value here to `gl.LINEAR` (versus `gl.NEAREST` for
// `MPMask` where we do not want to interpolate mask values, especially for
// category masks).
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
}
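An illustrative helper (not part of the library) capturing the design choice discussed above: images use linear filtering for smooth resizing, while masks use nearest filtering so mask values are never interpolated.

function configureFiltering(gl: WebGL2RenderingContext, smooth: boolean): void {
  const filter = smooth ? gl.LINEAR : gl.NEAREST;  // images vs. masks
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter);
}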
/** /**
* Binds the backing texture to the canvas. If the texture does not yet * Binds the backing texture to the canvas. If the texture does not yet
* exist, creates it first. * exist, creates it first.
@ -318,16 +333,12 @@ export class MPImage {
assertNotNull(gl.createTexture(), 'Failed to create texture'); assertNotNull(gl.createTexture(), 'Failed to create texture');
this.containers.push(webGLTexture); this.containers.push(webGLTexture);
this.ownsWebGLTexture = true; this.ownsWebGLTexture = true;
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
this.configureTextureParams();
} else {
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
} }
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
// TODO: Ideally, we would only set these once per texture and
// not once every frame.
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
return webGLTexture; return webGLTexture;
} }

View File

@ -60,8 +60,11 @@ class MPMaskTestContext {
} }
this.webGLTexture = gl.createTexture()!; this.webGLTexture = gl.createTexture()!;
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture); gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT, gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT,
new Float32Array(pixels).map(v => v / 255)); new Float32Array(pixels).map(v => v / 255));

View File

@ -15,6 +15,7 @@
*/ */
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {isIOS} from '../../../../web/graph_runner/platform_utils';
/** Number of instances a user can keep alive before we raise a warning. */ /** Number of instances a user can keep alive before we raise a warning. */
const INSTANCE_COUNT_WARNING_THRESHOLD = 250; const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@ -32,6 +33,8 @@ enum MPMaskType {
/** The supported mask formats. For internal usage. */ /** The supported mask formats. For internal usage. */
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture; export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
/** /**
* The wrapper class for MediaPipe segmentation masks. * The wrapper class for MediaPipe segmentation masks.
* *
@ -56,6 +59,9 @@ export class MPMask {
*/ */
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD; private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
/** The format used to write pixel values from textures. */
private static texImage2DFormat?: GLenum;
/** @hideconstructor */ /** @hideconstructor */
constructor( constructor(
private readonly containers: MPMaskContainer[], private readonly containers: MPMaskContainer[],
@ -127,6 +133,29 @@ export class MPMask {
return this.convertToWebGLTexture(); return this.convertToWebGLTexture();
} }
/**
* Returns the texture format used for writing float textures on this
* platform.
*/
getTexImage2DFormat(): GLenum {
const gl = this.getGL();
if (!MPMask.texImage2DFormat) {
// Note: This is the same check we use in
// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
if (gl.getExtension('EXT_color_buffer_float') &&
gl.getExtension('OES_texture_float_linear') &&
gl.getExtension('EXT_float_blend')) {
MPMask.texImage2DFormat = gl.R32F;
} else if (gl.getExtension('EXT_color_buffer_half_float')) {
MPMask.texImage2DFormat = gl.R16F;
} else {
throw new Error(
'GPU does not fully support 4-channel float32 or float16 formats');
}
}
return MPMask.texImage2DFormat;
}
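The same capability probe, extracted as a standalone TypeScript sketch (illustrative only; assumes a WebGL2 context is available):

function singleChannelFloatFormat(gl: WebGL2RenderingContext): GLenum {
  // Prefer full float32 when the context can render and blend float textures.
  if (gl.getExtension('EXT_color_buffer_float') &&
      gl.getExtension('OES_texture_float_linear') &&
      gl.getExtension('EXT_float_blend')) {
    return gl.R32F;
  }
  // Otherwise fall back to float16 if half-float render targets are supported.
  if (gl.getExtension('EXT_color_buffer_half_float')) {
    return gl.R16F;
  }
  throw new Error('GPU does not fully support float32 or float16 formats');
}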
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined; private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined; private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined; private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -175,8 +204,10 @@ export class MPMask {
destinationContainer = destinationContainer =
assertNotNull(gl.createTexture(), 'Failed to create texture'); assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, destinationContainer); gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
this.configureTextureParams();
const format = this.getTexImage2DFormat();
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, null); gl.FLOAT, null);
gl.bindTexture(gl.TEXTURE_2D, null); gl.bindTexture(gl.TEXTURE_2D, null);
@ -207,7 +238,7 @@ export class MPMask {
if (!this.canvas) { if (!this.canvas) {
throw new Error( throw new Error(
'Conversion to different image formats require that a canvas ' + 'Conversion to different image formats require that a canvas ' +
'is passed when iniitializing the image.'); 'is passed when initializing the image.');
} }
if (!this.gl) { if (!this.gl) {
this.gl = assertNotNull( this.gl = assertNotNull(
@ -215,11 +246,6 @@ export class MPMask {
'You cannot use a canvas that is already bound to a different ' + 'You cannot use a canvas that is already bound to a different ' +
'type of rendering context.'); 'type of rendering context.');
} }
const ext = this.gl.getExtension('EXT_color_buffer_float');
if (!ext) {
// TODO: Ensure this works on iOS
throw new Error('Missing required EXT_color_buffer_float extension');
}
return this.gl; return this.gl;
} }
@ -237,18 +263,34 @@ export class MPMask {
if (uint8Array) { if (uint8Array) {
float32Array = new Float32Array(uint8Array).map(v => v / 255); float32Array = new Float32Array(uint8Array).map(v => v / 255);
} else { } else {
float32Array = new Float32Array(this.width * this.height);
const gl = this.getGL(); const gl = this.getGL();
const shaderContext = this.getShaderContext(); const shaderContext = this.getShaderContext();
float32Array = new Float32Array(this.width * this.height);
// Create texture if needed // Create texture if needed
const webGlTexture = this.convertToWebGLTexture(); const webGlTexture = this.convertToWebGLTexture();
// Create a framebuffer from the texture and read back pixels // Create a framebuffer from the texture and read back pixels
shaderContext.bindFramebuffer(gl, webGlTexture); shaderContext.bindFramebuffer(gl, webGlTexture);
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array); if (isIOS()) {
shaderContext.unbindFramebuffer(); // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
// Uint16Array, however, and requires a manual bitwise conversion from
// Uint16 to floating point numbers. This conversion is more expensive
// than reading back a Float32Array from the RGBA image and dropping
// the superfluous data, so we do this instead.
const outputArray = new Float32Array(this.width * this.height * 4);
gl.readPixels(
0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
float32Array[i] = outputArray[j];
}
} else {
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
}
} }
this.containers.push(float32Array); this.containers.push(float32Array);
} }
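A compact standalone sketch of the workaround described above (the helper name and signature are illustrative, not library API): read the full RGBA float image and keep only the red channel, which on iOS is cheaper than a HALF_FLOAT read followed by manual float16 decoding.

function readRedChannelF32(
    gl: WebGL2RenderingContext, width: number, height: number): Float32Array {
  // Assumes a float-renderable framebuffer is currently bound.
  const rgba = new Float32Array(width * height * 4);
  gl.readPixels(0, 0, width, height, gl.RGBA, gl.FLOAT, rgba);
  const red = new Float32Array(width * height);
  for (let i = 0, j = 0; i < red.length; ++i, j += 4) {
    red[i] = rgba[j];  // drop the superfluous G, B and A values
  }
  return red;
}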
@ -273,9 +315,9 @@ export class MPMask {
webGLTexture = this.bindTexture(); webGLTexture = this.bindTexture();
const data = this.convertToFloat32Array(); const data = this.convertToFloat32Array();
// TODO: Add support for R16F to support iOS const format = this.getTexImage2DFormat();
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, data); gl.FLOAT, data);
this.unbindTexture(); this.unbindTexture();
} }
@ -283,6 +325,19 @@ export class MPMask {
return webGLTexture; return webGLTexture;
} }
/** Sets texture params for the currently bound texture. */
private configureTextureParams() {
const gl = this.getGL();
// `gl.NEAREST` ensures that we do not get interpolated values for
// masks. In some cases, the user might want interpolation (e.g. for
// confidence masks), so we might want to make this user-configurable.
// Note that `MPImage` uses `gl.LINEAR`.
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
}
/** /**
* Binds the backing texture to the canvas. If the texture does not yet * Binds the backing texture to the canvas. If the texture does not yet
* exist, creates it first. * exist, creates it first.
@ -299,15 +354,12 @@ export class MPMask {
assertNotNull(gl.createTexture(), 'Failed to create texture'); assertNotNull(gl.createTexture(), 'Failed to create texture');
this.containers.push(webGLTexture); this.containers.push(webGLTexture);
this.ownsWebGLTexture = true; this.ownsWebGLTexture = true;
}
gl.bindTexture(gl.TEXTURE_2D, webGLTexture); gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
// TODO: Ideally, we would only set these once per texture and this.configureTextureParams();
// not once every frame. } else {
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); }
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
return webGLTexture; return webGLTexture;
} }
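The bindTexture change above configures texture parameters only when the texture is first created instead of on every bind. A small sketch of that create-once/bind-later pattern, with hypothetical names and not the actual MPMask implementation:

// Sketch: configure wrap/filter parameters exactly once, at creation time.
class LazyTexture {
  private texture?: WebGLTexture;

  constructor(private readonly gl: WebGL2RenderingContext) {}

  bind(): WebGLTexture {
    if (!this.texture) {
      this.texture = this.gl.createTexture()!;
      this.gl.bindTexture(this.gl.TEXTURE_2D, this.texture);
      // Later binds skip these texParameteri calls entirely.
      this.gl.texParameteri(
          this.gl.TEXTURE_2D, this.gl.TEXTURE_WRAP_S, this.gl.CLAMP_TO_EDGE);
      this.gl.texParameteri(
          this.gl.TEXTURE_2D, this.gl.TEXTURE_WRAP_T, this.gl.CLAMP_TO_EDGE);
      this.gl.texParameteri(
          this.gl.TEXTURE_2D, this.gl.TEXTURE_MIN_FILTER, this.gl.NEAREST);
      this.gl.texParameteri(
          this.gl.TEXTURE_2D, this.gl.TEXTURE_MAG_FILTER, this.gl.NEAREST);
    } else {
      this.gl.bindTexture(this.gl.TEXTURE_2D, this.texture);
    }
    return this.texture;
  }
}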

View File

@ -191,7 +191,8 @@ describe('FaceDetector', () => {
categoryName: '', categoryName: '',
displayName: '', displayName: '',
}], }],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0} boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
}); });
}); });
}); });

View File

@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
/** /**
* Performs face stylization on the provided single image and returns the * Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be * result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the * used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode. * FaceStylizer is created with the image running mode.
* *
* @param image An image to process. * @param image An image to process.
@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
/** /**
* Performs face stylization on the provided single image and returns the * Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be * result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the * used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode. * FaceStylizer is created with the image running mode.
* *
* The 'imageProcessingOptions' parameter can be used to specify one or all * The 'imageProcessingOptions' parameter can be used to specify one or all
@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
/** /**
* Performs face stylization on the provided video frame. This method creates * Performs face stylization on the provided video frame. This method creates
* a copy of the resulting image and should not be used in high-throughput * a copy of the resulting image and should not be used in high-throughput
* applictions. Only use this method when the FaceStylizer is created with the * applications. Only use this method when the FaceStylizer is created with the
* video running mode. * video running mode.
* *
* The input frame can be of any size. It's required to provide the video * The input frame can be of any size. It's required to provide the video

View File

@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect'; const NORM_RECT_STREAM = 'norm_rect';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask'; const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGE_SEGMENTER_GRAPH = const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner { export class ImageSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask; private categoryMask?: MPMask;
private confidenceMasks?: MPMask[]; private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private labels: string[] = []; private labels: string[] = [];
private userCallback?: ImageSegmenterCallback; private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/** /**
* Performs image segmentation on the provided single image and returns the * Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and * segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method * should not be used in high-throughput applications. Only use this method
* when the ImageSegmenter is created with running mode `image`. * when the ImageSegmenter is created with running mode `image`.
* *
* @param image An image to process. * @param image An image to process.
@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/** /**
* Performs image segmentation on the provided single image and returns the * Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and * segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method when * should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `image`. * the ImageSegmenter is created with running mode `image`.
* *
* @param image An image to process. * @param image An image to process.
@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/** /**
* Performs image segmentation on the provided video frame and returns the * Performs image segmentation on the provided video frame and returns the
* segmentation result. This method creates a copy of the resulting masks and * segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method when * should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `video`. * the ImageSegmenter is created with running mode `video`.
* *
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
private reset(): void { private reset(): void {
this.categoryMask = undefined; this.categoryMask = undefined;
this.confidenceMasks = undefined; this.confidenceMasks = undefined;
this.qualityScores = undefined;
} }
private processResults(): ImageSegmenterResult|void { private processResults(): ImageSegmenterResult|void {
try { try {
const result = const result = new ImageSegmenterResult(
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask); this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(result); this.userCallback(result);
} else { } else {
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
}); });
} }
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
QUALITY_SCORES_STREAM, (scores, timestamp) => {
this.qualityScores = scores;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
QUALITY_SCORES_STREAM, timestamp => {
this.qualityScores = undefined;
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
} }
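The new quality_scores wiring above pairs a float-vector listener with an empty-packet listener so the latest output timestamp still advances on frames where the stream emits nothing. A hedged sketch of that pattern as a standalone helper (hypothetical name; the listener methods are the graphRunner calls used in the diff):

// Sketch: keep timestamps in sync whether or not an optional stream fires.
function attachOptionalVectorStream(
    graphRunner: {
      attachFloatVectorListener(
          stream: string,
          cb: (data: number[], timestamp: number) => void): void;
      attachEmptyPacketListener(
          stream: string, cb: (timestamp: number) => void): void;
    },
    stream: string,
    onData: (data?: number[]) => void,
    onTimestamp: (timestamp: number) => void): void {
  graphRunner.attachFloatVectorListener(stream, (data, timestamp) => {
    onData(data);
    onTimestamp(timestamp);
  });
  graphRunner.attachEmptyPacketListener(stream, timestamp => {
    onData(undefined);  // no packet this frame; clear any stale value
    onTimestamp(timestamp);
  });
}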

View File

@ -30,7 +30,13 @@ export class ImageSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class * `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to. * which the pixel in the original image was predicted to belong to.
*/ */
readonly categoryMask?: MPMask) {} readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */ /** Frees the resources held by the category and confidence masks. */
close(): void { close(): void {
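A hedged usage sketch of the new field: reading qualityScores inside a segmentation callback. The import path, `segmenter`, and `image` are assumptions for the example, not part of the diff.

import {ImageSegmenter} from '@mediapipe/tasks-vision';

declare const segmenter: ImageSegmenter;       // assumed to be already created
declare const image: HTMLImageElement;         // assumed input image

segmenter.segment(image, result => {
  // One entry per category/mask, in [0, 1]; may be undefined if the model
  // does not report quality scores.
  (result.qualityScores ?? []).forEach((score, i) => {
    console.log(`mask ${i}: quality ${score.toFixed(3)}`);
  });
});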

View File

@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
((images: WasmImage, timestamp: number) => void)|undefined; ((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener: confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined; ((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
constructor() { constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null); super(createSpyWasmModule(), /* glCanvas= */ null);
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
expect(stream).toEqual('confidence_masks'); expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener; this.confidenceMasksListener = listener;
}); });
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
}); });
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
it('invokes listener after masks are available', async () => { it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false; let listenerCalled = false;
await imageSegmenter.setOptions( await imageSegmenter.setOptions(
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
], ],
1337); 1337);
expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse();
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, () => { imageSegmenter.segment({} as HTMLImageElement, result => {
listenerCalled = true; listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve(); resolve();
}); });
}); });

View File

@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in'; const ROI_IN_STREAM = 'roi_in';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask'; const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGEA_SEGMENTER_GRAPH = const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false; const DEFAULT_OUTPUT_CATEGORY_MASK = false;
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
export class InteractiveSegmenter extends VisionTaskRunner { export class InteractiveSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask; private categoryMask?: MPMask;
private confidenceMasks?: MPMask[]; private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback?: InteractiveSegmenterCallback; private userCallback?: InteractiveSegmenterCallback;
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
private reset(): void { private reset(): void {
this.confidenceMasks = undefined; this.confidenceMasks = undefined;
this.categoryMask = undefined; this.categoryMask = undefined;
this.qualityScores = undefined;
} }
private processResults(): InteractiveSegmenterResult|void { private processResults(): InteractiveSegmenterResult|void {
try { try {
const result = new InteractiveSegmenterResult( const result = new InteractiveSegmenterResult(
this.confidenceMasks, this.categoryMask); this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(result); this.userCallback(result);
} else { } else {
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
}); });
} }
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
QUALITY_SCORES_STREAM, (scores, timestamp) => {
this.qualityScores = scores;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
QUALITY_SCORES_STREAM, timestamp => {
this.qualityScores = undefined;
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
} }

View File

@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class * `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to. * which the pixel in the original image was predicted to belong to.
*/ */
readonly categoryMask?: MPMask) {} readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */ /** Frees the resources held by the category and confidence masks. */
close(): void { close(): void {

View File

@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
((images: WasmImage, timestamp: number) => void)|undefined; ((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener: confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined; ((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto; lastRoi?: RenderDataProto;
constructor() { constructor() {
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
expect(stream).toEqual('confidence_masks'); expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener; this.confidenceMasksListener = listener;
}); });
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
}); });
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
}); });
}); });
it('invokes listener after masks are avaiblae', async () => { it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false; let listenerCalled = false;
await interactiveSegmenter.setOptions( await interactiveSegmenter.setOptions(
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
], ],
1337); 1337);
expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse();
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => { interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
listenerCalled = true; listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve(); resolve();
}); });
}); });
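Mirroring the test above, a hedged sketch of the call shape on the application side: segmenting with a keypoint region of interest and reading the new qualityScores field. The import path, segmenter instance, and keypoint coordinates are illustrative assumptions.

import {InteractiveSegmenter} from '@mediapipe/tasks-vision';

declare const interactiveSegmenter: InteractiveSegmenter;  // assumed instance
declare const image: HTMLImageElement;                     // assumed input

// A keypoint ROI in normalized coordinates (illustrative values).
const roi = {keypoint: {x: 0.5, y: 0.5}};

interactiveSegmenter.segment(image, roi, result => {
  console.log('category mask present:', result.categoryMask !== undefined);
  console.log('quality scores:', result.qualityScores);
});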

View File

@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
categoryName: '', categoryName: '',
displayName: '', displayName: '',
}], }],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0} boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
}); });
}); });
}); });

View File

@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
// it uses "CriOS". // it uses "CriOS".
return userAgent.includes('Safari') && !userAgent.includes('Chrome'); return userAgent.includes('Safari') && !userAgent.includes('Chrome');
} }
/** Detect if code is running on iOS. */
export function isIOS() {
// Source:
// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
return [
'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
'iPod'
// tslint:disable-next-line:deprecation
].includes(navigator.platform)
// iPad on iOS 13 detection
|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
}
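Since isIOS() touches both navigator and document, a caller running outside a full browser context (e.g. a worker or server-side rendering) may want to guard the check. A minimal sketch; `runsOnIOS` and the import path are illustrative assumptions:

import {isIOS} from './platform_utils';  // path is illustrative

export function runsOnIOS(): boolean {
  // Bail out where the globals the detection relies on are absent.
  if (typeof navigator === 'undefined' || typeof document === 'undefined') {
    return false;
  }
  return isIOS();
}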

View File

@ -357,7 +357,10 @@ class BuildExtension(build_ext.build_ext):
for ext in self.extensions: for ext in self.extensions:
target_name = self.get_ext_fullpath(ext.name) target_name = self.get_ext_fullpath(ext.name)
# Build x86 # Build x86
self._build_binary(ext) self._build_binary(
ext,
['--cpu=darwin', '--ios_multi_cpus=i386,x86_64,armv7,arm64'],
)
x86_name = self.get_ext_fullpath(ext.name) x86_name = self.get_ext_fullpath(ext.name)
# Build Arm64 # Build Arm64
ext.name = ext.name + '.arm64' ext.name = ext.name + '.arm64'

View File

@ -42,16 +42,15 @@ filegroup(
"include/flatbuffers/allocator.h", "include/flatbuffers/allocator.h",
"include/flatbuffers/array.h", "include/flatbuffers/array.h",
"include/flatbuffers/base.h", "include/flatbuffers/base.h",
"include/flatbuffers/bfbs_generator.h",
"include/flatbuffers/buffer.h", "include/flatbuffers/buffer.h",
"include/flatbuffers/buffer_ref.h", "include/flatbuffers/buffer_ref.h",
"include/flatbuffers/code_generator.h", "include/flatbuffers/code_generator.h",
"include/flatbuffers/code_generators.h", "include/flatbuffers/code_generators.h",
"include/flatbuffers/default_allocator.h", "include/flatbuffers/default_allocator.h",
"include/flatbuffers/detached_buffer.h", "include/flatbuffers/detached_buffer.h",
"include/flatbuffers/file_manager.h",
"include/flatbuffers/flatbuffer_builder.h", "include/flatbuffers/flatbuffer_builder.h",
"include/flatbuffers/flatbuffers.h", "include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flatc.h",
"include/flatbuffers/flex_flat_util.h", "include/flatbuffers/flex_flat_util.h",
"include/flatbuffers/flexbuffers.h", "include/flatbuffers/flexbuffers.h",
"include/flatbuffers/grpc.h", "include/flatbuffers/grpc.h",

View File

@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo(): def repo():
third_party_http_archive( third_party_http_archive(
name = "flatbuffers", name = "flatbuffers",
strip_prefix = "flatbuffers-23.1.21", strip_prefix = "flatbuffers-23.5.8",
sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238", sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
urls = [ urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz", "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
"https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz", "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
], ],
build_file = "//third_party/flatbuffers:BUILD.bazel", build_file = "//third_party/flatbuffers:BUILD.bazel",
delete = ["build_defs.bzl", "BUILD.bazel"], delete = ["build_defs.bzl", "BUILD.bazel"],