Pulled changes from master
This commit is contained in:
commit
164eae8c16
|
@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
|
|||
typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<std::string>
|
||||
ConcatenateStringVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator);
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateTfLiteTensorVectorCalculator"
|
||||
|
|
|
@ -30,13 +30,15 @@ namespace mediapipe {
|
|||
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
|
||||
|
||||
void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp,
|
||||
template <typename T>
|
||||
void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp,
|
||||
CalculatorRunner* runner) {
|
||||
runner->MutableInputs()->Index(index).packets.push_back(
|
||||
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
|
||||
MakePacket<std::vector<T>>(input).At(Timestamp(timestamp)));
|
||||
}
|
||||
|
||||
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
|
||||
template <typename T>
|
||||
void AddInputVectors(const std::vector<std::vector<T>>& inputs,
|
||||
int64_t timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
AddInputVector(i, inputs[i], timestamp, runner);
|
||||
|
@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
|
|||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateStringVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<std::string>> inputs = {
|
||||
{"a", "b"}, {"c"}, {"d", "e", "f"}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
|
||||
}
|
||||
|
||||
typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
|
||||
TestConcatenateUniqueIntPtrCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);
|
||||
|
|
|
@ -1099,6 +1099,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = True, # Defines TestServiceCalculator
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
|
|||
return std::move(SetNoLogging());
|
||||
}
|
||||
|
||||
StatusBuilder::operator Status() const& {
|
||||
StatusBuilder::operator absl::Status() const& {
|
||||
return StatusBuilder(*this).JoinMessageToStatus();
|
||||
}
|
||||
|
||||
StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
|
||||
StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
|
||||
|
||||
absl::Status StatusBuilder::JoinMessageToStatus() {
|
||||
if (!impl_) {
|
||||
|
|
|
@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
|
|||
return std::move(*this << msg);
|
||||
}
|
||||
|
||||
operator Status() const&;
|
||||
operator Status() &&;
|
||||
operator absl::Status() const&;
|
||||
operator absl::Status() &&;
|
||||
|
||||
absl::Status JoinMessageToStatus();
|
||||
|
||||
|
|
|
@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
|
|||
lhs op## = rhs; \
|
||||
return lhs; \
|
||||
}
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
|
||||
#undef STRONG_INT_VS_STRONG_INT_BINARY_OP
|
||||
|
||||
// Define operators that take one StrongInt and one native integer argument.
|
||||
|
@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
|
|||
rhs op## = lhs; \
|
||||
return rhs; \
|
||||
}
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(*);
|
||||
NUMERIC_VS_STRONG_INT_BINARY_OP(*);
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(/);
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(%);
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(*)
|
||||
NUMERIC_VS_STRONG_INT_BINARY_OP(*)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(/)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(%)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators)
|
||||
#undef STRONG_INT_VS_NUMERIC_BINARY_OP
|
||||
#undef NUMERIC_VS_STRONG_INT_BINARY_OP
|
||||
|
||||
|
@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
|
|||
StrongInt<TagType, ValueType, ValidatorType> rhs) { \
|
||||
return lhs.value() op rhs.value(); \
|
||||
}
|
||||
STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators)
|
||||
#undef STRONG_INT_COMPARISON_OP
|
||||
|
||||
} // namespace intops
|
||||
|
|
|
@ -44,7 +44,6 @@ class GraphServiceBase {
|
|||
|
||||
constexpr GraphServiceBase(const char* key) : key(key) {}
|
||||
|
||||
virtual ~GraphServiceBase() = default;
|
||||
inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
|
||||
return DefaultInitializationUnsupported();
|
||||
}
|
||||
|
@ -52,14 +51,32 @@ class GraphServiceBase {
|
|||
const char* key;
|
||||
|
||||
protected:
|
||||
// `GraphService<T>` objects, deriving `GraphServiceBase` are designed to be
|
||||
// global constants and not ever deleted through `GraphServiceBase`. Hence,
|
||||
// protected and non-virtual destructor which helps to make `GraphService<T>`
|
||||
// trivially destructible and properly defined as global constants.
|
||||
//
|
||||
// A class with any virtual functions should have a destructor that is either
|
||||
// public and virtual or else protected and non-virtual.
|
||||
// https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual
|
||||
~GraphServiceBase() = default;
|
||||
|
||||
absl::Status DefaultInitializationUnsupported() const {
|
||||
return absl::UnimplementedError(absl::StrCat(
|
||||
"Graph service '", key, "' does not support default initialization"));
|
||||
}
|
||||
};
|
||||
|
||||
// A global constant to refer a service:
|
||||
// - Requesting `CalculatorContract::UseService` from calculator
|
||||
// - Accessing `Calculator/SubgraphContext::Service`from calculator/subgraph
|
||||
// - Setting before graph initialization `CalculatorGraph::SetServiceObject`
|
||||
//
|
||||
// NOTE: In headers, define your graph service reference safely as following:
|
||||
// `inline constexpr GraphService<YourService> kYourService("YourService");`
|
||||
//
|
||||
template <typename T>
|
||||
class GraphService : public GraphServiceBase {
|
||||
class GraphService final : public GraphServiceBase {
|
||||
public:
|
||||
using type = T;
|
||||
using packet_type = std::shared_ptr<T>;
|
||||
|
@ -68,7 +85,7 @@ class GraphService : public GraphServiceBase {
|
|||
kDisallowDefaultInitialization)
|
||||
: GraphServiceBase(my_key), default_init_(default_init) {}
|
||||
|
||||
absl::StatusOr<Packet> CreateDefaultObject() const override {
|
||||
absl::StatusOr<Packet> CreateDefaultObject() const final {
|
||||
if (default_init_ != kAllowDefaultInitialization) {
|
||||
return DefaultInitializationUnsupported();
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
const GraphService<int> kIntService("mediapipe::IntService");
|
||||
constexpr GraphService<int> kIntService("mediapipe::IntService");
|
||||
} // namespace
|
||||
|
||||
TEST(GraphServiceManager, SetGetServiceObject) {
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
#include "mediapipe/framework/graph_service.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
|
@ -159,7 +161,7 @@ TEST_F(GraphServiceTest, CreateDefault) {
|
|||
|
||||
struct TestServiceData {};
|
||||
|
||||
const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
|
||||
constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
|
||||
"kTestServiceAllowDefaultInitialization",
|
||||
GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
|
@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest,
|
|||
HasSubstr("Service is unavailable.")));
|
||||
}
|
||||
|
||||
const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization(
|
||||
"kTestServiceDisallowDefaultInitialization",
|
||||
GraphServiceBase::kDisallowDefaultInitialization);
|
||||
constexpr GraphService<TestServiceData>
|
||||
kTestServiceDisallowDefaultInitialization(
|
||||
"kTestServiceDisallowDefaultInitialization",
|
||||
GraphServiceBase::kDisallowDefaultInitialization);
|
||||
|
||||
static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>,
|
||||
"GraphService is not trivially destructible");
|
||||
|
||||
class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator
|
||||
: public CalculatorBase {
|
||||
|
|
|
@ -16,15 +16,6 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
const GraphService<TestServiceObject> kTestService(
|
||||
"test_service", GraphServiceBase::kDisallowDefaultInitialization);
|
||||
const GraphService<int> kAnotherService(
|
||||
"another_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
const GraphService<NoDefaultConstructor> kNoDefaultService(
|
||||
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
const GraphService<NeedsCreateMethod> kNeedsCreateService(
|
||||
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
|
|
|
@ -22,14 +22,17 @@ namespace mediapipe {
|
|||
|
||||
using TestServiceObject = std::map<std::string, int>;
|
||||
|
||||
extern const GraphService<TestServiceObject> kTestService;
|
||||
extern const GraphService<int> kAnotherService;
|
||||
inline constexpr GraphService<TestServiceObject> kTestService(
|
||||
"test_service", GraphServiceBase::kDisallowDefaultInitialization);
|
||||
inline constexpr GraphService<int> kAnotherService(
|
||||
"another_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
class NoDefaultConstructor {
|
||||
public:
|
||||
NoDefaultConstructor() = delete;
|
||||
};
|
||||
extern const GraphService<NoDefaultConstructor> kNoDefaultService;
|
||||
inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService(
|
||||
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
class NeedsCreateMethod {
|
||||
public:
|
||||
|
@ -40,7 +43,8 @@ class NeedsCreateMethod {
|
|||
private:
|
||||
NeedsCreateMethod() = default;
|
||||
};
|
||||
extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
|
||||
inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService(
|
||||
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
// Use a service.
|
||||
class TestServiceCalculator : public CalculatorBase {
|
||||
|
|
|
@ -57,7 +57,7 @@ namespace mediapipe {
|
|||
// have underflow/overflow etc. This type is used internally by Timestamp
|
||||
// and TimestampDiff.
|
||||
MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
|
||||
mediapipe::intops::LogFatalOnError);
|
||||
mediapipe::intops::LogFatalOnError)
|
||||
|
||||
class TimestampDiff;
|
||||
|
||||
|
|
|
@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
|
|||
#define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
|
||||
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
|
||||
mediapipe::PacketTypeIdToMediaPipeTypeData, \
|
||||
mediapipe::tool::GetTypeHash< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
|
||||
mediapipe::TypeId::Of< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
|
||||
.hash_code(), \
|
||||
(mediapipe::MediaPipeTypeData{ \
|
||||
mediapipe::tool::GetTypeHash< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
|
||||
mediapipe::TypeId::Of< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
|
||||
.hash_code(), \
|
||||
type_name, serialize_fn, deserialize_fn})); \
|
||||
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
|
||||
mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \
|
||||
(mediapipe::MediaPipeTypeData{ \
|
||||
mediapipe::tool::GetTypeHash< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
|
||||
mediapipe::TypeId::Of< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
|
||||
.hash_code(), \
|
||||
type_name, serialize_fn, deserialize_fn}));
|
||||
// End define MEDIAPIPE_REGISTER_TYPE.
|
||||
|
||||
|
|
|
@ -38,7 +38,10 @@ cc_library(
|
|||
srcs = ["gpu_service.cc"],
|
||||
hdrs = ["gpu_service.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:graph_service"] + select({
|
||||
deps = [
|
||||
"//mediapipe/framework:graph_service",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
":gpu_shared_data_internal",
|
||||
],
|
||||
|
@ -292,6 +295,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"@com_google_absl//absl/functional:bind_front",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
] + select({
|
||||
|
@ -630,6 +634,7 @@ cc_library(
|
|||
"//mediapipe/framework:executor",
|
||||
"//mediapipe/framework/deps:no_destructor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:apple": [
|
||||
|
|
|
@ -47,6 +47,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height,
|
|||
auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height,
|
||||
format, nullptr);
|
||||
if (!buf->CreateInternal(data, alignment)) {
|
||||
LOG(WARNING) << "Failed to create a GL texture";
|
||||
return nullptr;
|
||||
}
|
||||
return buf;
|
||||
|
@ -106,7 +107,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width,
|
|||
|
||||
bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
|
||||
auto context = GlContext::GetCurrent();
|
||||
if (!context) return false;
|
||||
if (!context) {
|
||||
LOG(WARNING) << "Cannot create a GL texture without a valid context";
|
||||
return false;
|
||||
}
|
||||
|
||||
producer_context_ = context; // Save creation GL context.
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||
|
@ -72,8 +73,10 @@ class GpuBuffer {
|
|||
// are not portable. Applications and calculators should normally obtain
|
||||
// GpuBuffers in a portable way from the framework, e.g. using
|
||||
// GpuBufferMultiPool.
|
||||
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
|
||||
: holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
|
||||
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
|
||||
CHECK(storage) << "Cannot construct GpuBuffer with null storage";
|
||||
holder_ = std::make_shared<StorageHolder>(std::move(storage));
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
// This is used to support backward-compatible construction of GpuBuffer from
|
||||
|
|
|
@ -28,6 +28,12 @@ namespace mediapipe {
|
|||
#define GL_HALF_FLOAT 0x140B
|
||||
#endif // GL_HALF_FLOAT
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#ifndef GL_HALF_FLOAT_OES
|
||||
#define GL_HALF_FLOAT_OES 0x8D61
|
||||
#endif // GL_HALF_FLOAT_OES
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#ifdef GL_ES_VERSION_2_0
|
||||
static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
|
||||
|
@ -48,6 +54,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
|
|||
case GL_RG8:
|
||||
info->gl_internal_format = info->gl_format = GL_RG_EXT;
|
||||
return;
|
||||
#ifdef __EMSCRIPTEN__
|
||||
case GL_RGBA16F:
|
||||
info->gl_internal_format = GL_RGBA;
|
||||
info->gl_type = GL_HALF_FLOAT_OES;
|
||||
return;
|
||||
#endif // __EMSCRIPTEN__
|
||||
default:
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_
|
||||
#define MEDIAPIPE_GPU_GPU_SERVICE_H_
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "mediapipe/framework/graph_service.h"
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -29,7 +30,7 @@ class GpuResources {
|
|||
};
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
extern const GraphService<GpuResources> kGpuService;
|
||||
ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "mediapipe/framework/deps/no_destructor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/gpu/gl_context.h"
|
||||
|
@ -116,7 +117,7 @@ GpuResources::~GpuResources() {
|
|||
#endif // __APPLE__
|
||||
}
|
||||
|
||||
extern const GraphService<GpuResources> kGpuService;
|
||||
ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
|
||||
|
||||
absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
|
||||
CHECK(node->Contract().ServiceRequests().contains(kGpuService.key));
|
||||
|
|
|
@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
|
|||
"""Instantiates perceptual loss.
|
||||
|
||||
Args:
|
||||
feature_weight: The weight coeffcients of multiple model extracted
|
||||
feature_weight: The weight coefficients of multiple model extracted
|
||||
features used for calculating the perceptual loss.
|
||||
loss_weight: The weight coefficients between `style_loss` and
|
||||
`content_loss`.
|
||||
|
|
|
@ -105,7 +105,7 @@ class FaceStylizer(object):
|
|||
self._train_model(train_data=train_data, preprocessor=self._preprocessor)
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates the componenets of face stylizer."""
|
||||
"""Creates the components of face stylizer."""
|
||||
self._encoder = model_util.load_keras_model(
|
||||
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
|
||||
)
|
||||
|
@ -138,7 +138,7 @@ class FaceStylizer(object):
|
|||
"""
|
||||
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
|
||||
|
||||
# TODO: Support processing mulitple input style images. The
|
||||
# TODO: Support processing multiple input style images. The
|
||||
# input style images are expected to have similar style.
|
||||
# style_sample represents a tuple of (style_image, style_label).
|
||||
style_sample = next(iter(train_dataset))
|
||||
|
|
|
@ -103,8 +103,8 @@ class ModelResourcesCache {
|
|||
};
|
||||
|
||||
// Global service for mediapipe task model resources cache.
|
||||
const mediapipe::GraphService<ModelResourcesCache> kModelResourcesCacheService(
|
||||
"mediapipe::tasks::ModelResourcesCacheService");
|
||||
inline constexpr mediapipe::GraphService<ModelResourcesCache>
|
||||
kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService");
|
||||
|
||||
} // namespace core
|
||||
} // namespace tasks
|
||||
|
|
|
@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
|
|||
static constexpr Output<Image>::Multiple kConfidenceMaskOut{
|
||||
"CONFIDENCE_MASK"};
|
||||
static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
|
||||
static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
|
||||
"QUALITY_SCORES"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
|
||||
kConfidenceMaskOut, kCategoryMaskOut);
|
||||
kConfidenceMaskOut, kCategoryMaskOut,
|
||||
kQualityScoresOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
|
@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
|
|||
|
||||
absl::Status TensorsToSegmentationCalculator::Process(
|
||||
mediapipe::CalculatorContext* cc) {
|
||||
RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
|
||||
<< "Expect a vector of single Tensor.";
|
||||
const auto& input_tensor = kTensorsIn(cc).Get()[0];
|
||||
const auto& input_tensors = kTensorsIn(cc).Get();
|
||||
if (input_tensors.size() != 1 && input_tensors.size() != 2) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Expect input tensor vector of size 1 or 2.");
|
||||
}
|
||||
const auto& input_tensor = *input_tensors.rbegin();
|
||||
ASSIGN_OR_RETURN(const Shape input_shape,
|
||||
GetImageLikeTensorShape(input_tensor));
|
||||
|
||||
// TODO: should use tensor signature to get the correct output
|
||||
// tensor.
|
||||
if (input_tensors.size() == 2) {
|
||||
const auto& quality_tensor = input_tensors[0];
|
||||
const float* quality_score_buffer =
|
||||
quality_tensor.GetCpuReadView().buffer<float>();
|
||||
const std::vector<float> quality_scores(
|
||||
quality_score_buffer,
|
||||
quality_score_buffer +
|
||||
(quality_tensor.bytes() / quality_tensor.element_size()));
|
||||
kQualityScoresOut(cc).Send(quality_scores);
|
||||
} else {
|
||||
// If the input_tensors don't contain quality scores, send the default
|
||||
// quality scores as 1.
|
||||
const std::vector<float> quality_scores(input_shape.channels, 1.0f);
|
||||
kQualityScoresOut(cc).Send(quality_scores);
|
||||
}
|
||||
|
||||
// Category mask does not require activation function.
|
||||
if (options_.segmenter_options().output_type() ==
|
||||
SegmenterOptions::CONFIDENCE_MASK &&
|
||||
|
|
|
@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
|
|||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kQualityScoresStreamName[] = "quality_scores";
|
||||
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||
graph.Out(kCategoryMaskTag);
|
||||
}
|
||||
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
|
||||
graph.Out(kQualityScoresTag);
|
||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
if (enable_flow_limiting) {
|
||||
|
@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
|||
category_mask =
|
||||
status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
const std::vector<float>& quality_scores =
|
||||
status_or_packets.value()[kQualityScoresStreamName]
|
||||
.Get<std::vector<float>>();
|
||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||
result_callback(
|
||||
{{confidence_masks, category_mask}}, image_packet.Get<Image>(),
|
||||
{{confidence_masks, category_mask, quality_scores}},
|
||||
image_packet.Get<Image>(),
|
||||
image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
|
||||
};
|
||||
}
|
||||
|
@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
|
|||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
const std::vector<float>& quality_scores =
|
||||
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
|
||||
return {{confidence_masks, category_mask, quality_scores}};
|
||||
}
|
||||
|
||||
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
|
||||
|
@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
|
|||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
const std::vector<float>& quality_scores =
|
||||
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
|
||||
return {{confidence_masks, category_mask, quality_scores}};
|
||||
}
|
||||
|
||||
absl::Status ImageSegmenter::SegmentAsync(
|
||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
|||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
|
||||
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
|
||||
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
|
||||
|
||||
// Struct holding the different output streams produced by the image segmenter
|
||||
|
@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
|
|||
std::optional<std::vector<Source<Image>>> confidence_masks;
|
||||
std::optional<Source<Image>> category_mask;
|
||||
// The same as the input image, mainly used for live stream mode.
|
||||
std::optional<Source<std::vector<float>>> quality_scores;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
|
@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
|
|||
"Segmentation tflite models are assumed to have a single subgraph.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
if (primary_subgraph->outputs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Segmentation tflite models are assumed to have a single output.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
*options->mutable_label_items(),
|
||||
GetLabelItemsIfAny(*metadata_extractor,
|
||||
*metadata_extractor->GetOutputTensorMetadata()->Get(0),
|
||||
segmenter_option.display_names_locale()));
|
||||
GetLabelItemsIfAny(
|
||||
*metadata_extractor,
|
||||
**metadata_extractor->GetOutputTensorMetadata()->crbegin(),
|
||||
segmenter_option.display_names_locale()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
|
|||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
const auto* output_tensor =
|
||||
(*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
|
||||
(*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
|
||||
return output_tensor;
|
||||
}
|
||||
|
||||
uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
|
||||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
return primary_subgraph->outputs()->size();
|
||||
}
|
||||
|
||||
// Get the input tensor from the tflite model of given model resources.
|
||||
absl::StatusOr<const tflite::Tensor*> GetInputTensor(
|
||||
const core::ModelResources& model_resources) {
|
||||
|
@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
}
|
||||
if (output_streams.quality_scores) {
|
||||
*output_streams.quality_scores >>
|
||||
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
|
||||
}
|
||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
|
||||
}
|
||||
}
|
||||
auto quality_scores =
|
||||
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
|
||||
return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
|
||||
/*confidence_masks=*/std::nullopt,
|
||||
/*category_mask=*/std::nullopt,
|
||||
/*quality_scores=*/quality_scores,
|
||||
/*image=*/image_and_tensors.image};
|
||||
} else {
|
||||
std::optional<std::vector<Source<Image>>> confidence_masks;
|
||||
|
@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
if (output_category_mask_) {
|
||||
category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
auto quality_scores =
|
||||
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
|
||||
return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
|
||||
/*confidence_masks=*/confidence_masks,
|
||||
/*category_mask=*/category_mask,
|
||||
/*quality_scores=*/quality_scores,
|
||||
/*image=*/image_and_tensors.image};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,10 @@ struct ImageSegmenterResult {
|
|||
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
||||
// the class which the pixel in the original image was predicted to belong to.
|
||||
std::optional<Image> category_mask;
|
||||
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
|
||||
// `1` if the model doesn't output quality scores. Each element corresponds to
|
||||
// the score of the category in the model outputs.
|
||||
std::vector<float> quality_scores;
|
||||
};
|
||||
|
||||
} // namespace image_segmenter
|
||||
|
|
|
@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
|
|||
constexpr char kImageOutStreamName[] = "image_out";
|
||||
constexpr char kRoiStreamName[] = "roi_in";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
constexpr char kQualityScoresStreamName[] = "quality_scores";
|
||||
|
||||
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
|
||||
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
|
||||
constexpr absl::string_view kImageTag{"IMAGE"};
|
||||
constexpr absl::string_view kRoiTag{"ROI"};
|
||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
|
||||
|
||||
constexpr absl::string_view kSubgraphTypeName{
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
||||
|
@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||
graph.Out(kCategoryMaskTag);
|
||||
}
|
||||
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
|
||||
graph.Out(kQualityScoresTag);
|
||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||
|
@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
|
|||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
const std::vector<float>& quality_scores =
|
||||
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
|
||||
return {{confidence_masks, category_mask, quality_scores}};
|
||||
}
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
|
|
|
@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
|
|||
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
|
||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||
constexpr absl::string_view kRoiTag{"ROI"};
|
||||
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
|
||||
|
||||
// Updates the graph to return `roi` stream which has same dimension as
|
||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||
|
@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
|||
graph[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
}
|
||||
image_segmenter.Out(kQualityScoresTag) >>
|
||||
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
|
||||
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
|
|
|
@ -81,7 +81,7 @@ strip_api_include_path_prefix(
|
|||
"//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
|
||||
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
|
||||
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
|
||||
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectionResult.h",
|
||||
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -162,7 +162,7 @@ apple_static_xcframework(
|
|||
":MPPImageClassifierResult.h",
|
||||
":MPPObjectDetector.h",
|
||||
":MPPObjectDetectorOptions.h",
|
||||
":MPPObjectDetectionResult.h",
|
||||
":MPPObjectDetectorResult.h",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",
|
||||
|
|
|
@ -16,17 +16,6 @@
|
|||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* MediaPipe Tasks delegate.
|
||||
*/
|
||||
typedef NS_ENUM(NSUInteger, MPPDelegate) {
|
||||
/** CPU. */
|
||||
MPPDelegateCPU,
|
||||
|
||||
/** GPU. */
|
||||
MPPDelegateGPU
|
||||
} NS_SWIFT_NAME(Delegate);
|
||||
|
||||
/**
|
||||
* Holds the base options that is used for creation of any type of task. It has fields with
|
||||
* important information acceleration configuration, TFLite model source etc.
|
||||
|
@ -37,12 +26,6 @@ NS_SWIFT_NAME(BaseOptions)
|
|||
/** The path to the model asset to open and mmap in memory. */
|
||||
@property(nonatomic, copy) NSString *modelAssetPath;
|
||||
|
||||
/**
|
||||
* Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
|
||||
* delegate CPU is used.
|
||||
*/
|
||||
@property(nonatomic) MPPDelegate delegate;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
|
||||
|
||||
baseOptions.modelAssetPath = self.modelAssetPath;
|
||||
baseOptions.delegate = self.delegate;
|
||||
|
||||
return baseOptions;
|
||||
}
|
||||
|
|
|
@ -33,20 +33,6 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
|
|||
if (self.modelAssetPath) {
|
||||
baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
|
||||
}
|
||||
|
||||
switch (self.delegate) {
|
||||
case MPPDelegateCPU: {
|
||||
baseOptionsProto->mutable_acceleration()->mutable_tflite();
|
||||
break;
|
||||
}
|
||||
case MPPDelegateGPU: {
|
||||
// TODO: Provide an implementation for GPU Delegate.
|
||||
[NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."];
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -28,9 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertNotEqual( \
|
||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||
NSNotFound)
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||
|
||||
#define AssertEqualCategoryArrays(categories, expectedCategories) \
|
||||
XCTAssertEqual(categories.count, expectedCategories.count); \
|
||||
|
|
|
@ -29,9 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
|
|||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertNotEqual( \
|
||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||
NSNotFound)
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||
|
||||
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
|
||||
XCTAssertNotNil(textEmbedderResult); \
|
||||
|
|
|
@ -34,9 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertNotEqual( \
|
||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||
NSNotFound)
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||
|
||||
#define AssertEqualCategoryArrays(categories, expectedCategories) \
|
||||
XCTAssertEqual(categories.count, expectedCategories.count); \
|
||||
|
@ -670,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
|
||||
// Because of flow limiting, we cannot ensure that the callback will be
|
||||
// invoked `iterationCount` times.
|
||||
// An normal expectation will fail if expectation.fullfill() is not called
|
||||
// An normal expectation will fail if expectation.fulfill() is not called
|
||||
// `expectation.expectedFulfillmentCount` times.
|
||||
// If `expectation.isInverted = true`, the test will only succeed if
|
||||
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
||||
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||
// Since in our case we cannot predict how many times the expectation is
|
||||
// supposed to be fullfilled setting,
|
||||
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||
|
|
|
@ -32,9 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertNotEqual( \
|
||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||
NSNotFound)
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||
|
||||
#define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \
|
||||
XCTAssertEqual(category.index, expectedCategory.index, \
|
||||
|
@ -70,7 +68,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
|
||||
#pragma mark Results
|
||||
|
||||
+ (MPPObjectDetectionResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
+ (MPPObjectDetectorResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
(NSInteger)timestampInMilliseconds {
|
||||
NSArray<MPPDetection *> *detections = @[
|
||||
[[MPPDetection alloc] initWithCategories:@[
|
||||
|
@ -95,8 +93,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
keypoints:nil],
|
||||
];
|
||||
|
||||
return [[MPPObjectDetectionResult alloc] initWithDetections:detections
|
||||
timestampInMilliseconds:timestampInMilliseconds];
|
||||
return [[MPPObjectDetectorResult alloc] initWithDetections:detections
|
||||
timestampInMilliseconds:timestampInMilliseconds];
|
||||
}
|
||||
|
||||
- (void)assertDetections:(NSArray<MPPDetection *> *)detections
|
||||
|
@ -112,25 +110,25 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
}
|
||||
}
|
||||
|
||||
- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult
|
||||
isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult
|
||||
expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
|
||||
XCTAssertNotNil(objectDetectionResult);
|
||||
- (void)assertObjectDetectorResult:(MPPObjectDetectorResult *)objectDetectorResult
|
||||
isEqualToExpectedResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult
|
||||
expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
|
||||
XCTAssertNotNil(objectDetectorResult);
|
||||
|
||||
NSArray<MPPDetection *> *detectionsSubsetToCompare;
|
||||
XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount);
|
||||
if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) {
|
||||
detectionsSubsetToCompare = [objectDetectionResult.detections
|
||||
subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)];
|
||||
XCTAssertEqual(objectDetectorResult.detections.count, expectedDetectionsCount);
|
||||
if (objectDetectorResult.detections.count > expectedObjectDetectorResult.detections.count) {
|
||||
detectionsSubsetToCompare = [objectDetectorResult.detections
|
||||
subarrayWithRange:NSMakeRange(0, expectedObjectDetectorResult.detections.count)];
|
||||
} else {
|
||||
detectionsSubsetToCompare = objectDetectionResult.detections;
|
||||
detectionsSubsetToCompare = objectDetectorResult.detections;
|
||||
}
|
||||
|
||||
[self assertDetections:detectionsSubsetToCompare
|
||||
isEqualToExpectedDetections:expectedObjectDetectionResult.detections];
|
||||
isEqualToExpectedDetections:expectedObjectDetectorResult.detections];
|
||||
|
||||
XCTAssertEqual(objectDetectionResult.timestampInMilliseconds,
|
||||
expectedObjectDetectionResult.timestampInMilliseconds);
|
||||
XCTAssertEqual(objectDetectorResult.timestampInMilliseconds,
|
||||
expectedObjectDetectorResult.timestampInMilliseconds);
|
||||
}
|
||||
|
||||
#pragma mark File
|
||||
|
@ -195,28 +193,27 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
- (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
|
||||
usingObjectDetector:(MPPObjectDetector *)objectDetector
|
||||
maxResults:(NSInteger)maxResults
|
||||
equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult {
|
||||
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage
|
||||
error:nil];
|
||||
equalsObjectDetectorResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult {
|
||||
MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInImage:mppImage error:nil];
|
||||
|
||||
[self assertObjectDetectionResult:objectDetectionResult
|
||||
isEqualToExpectedResult:expectedObjectDetectionResult
|
||||
expectedDetectionsCount:maxResults > 0 ? maxResults
|
||||
: objectDetectionResult.detections.count];
|
||||
[self assertObjectDetectorResult:ObjectDetectorResult
|
||||
isEqualToExpectedResult:expectedObjectDetectorResult
|
||||
expectedDetectionsCount:maxResults > 0 ? maxResults
|
||||
: ObjectDetectorResult.detections.count];
|
||||
}
|
||||
|
||||
- (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
|
||||
usingObjectDetector:(MPPObjectDetector *)objectDetector
|
||||
maxResults:(NSInteger)maxResults
|
||||
|
||||
equalsObjectDetectionResult:
|
||||
(MPPObjectDetectionResult *)expectedObjectDetectionResult {
|
||||
equalsObjectDetectorResult:
|
||||
(MPPObjectDetectorResult *)expectedObjectDetectorResult {
|
||||
MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
|
||||
|
||||
[self assertResultsOfDetectInImage:mppImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:maxResults
|
||||
equalsObjectDetectionResult:expectedObjectDetectionResult];
|
||||
equalsObjectDetectorResult:expectedObjectDetectorResult];
|
||||
}
|
||||
|
||||
#pragma mark General Tests
|
||||
|
@ -266,10 +263,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:-1
|
||||
equalsObjectDetectionResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
0]];
|
||||
equalsObjectDetectorResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
0]];
|
||||
}
|
||||
|
||||
- (void)testDetectWithOptionsSucceeds {
|
||||
|
@ -280,10 +277,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:-1
|
||||
equalsObjectDetectionResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
0]];
|
||||
equalsObjectDetectorResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
0]];
|
||||
}
|
||||
|
||||
- (void)testDetectWithMaxResultsSucceeds {
|
||||
|
@ -297,10 +294,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:maxResults
|
||||
equalsObjectDetectionResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
0]];
|
||||
equalsObjectDetectorResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
0]];
|
||||
}
|
||||
|
||||
- (void)testDetectWithScoreThresholdSucceeds {
|
||||
|
@ -316,13 +313,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
boundingBox:CGRectMake(608, 161, 381, 439)
|
||||
keypoints:nil],
|
||||
];
|
||||
MPPObjectDetectionResult *expectedObjectDetectionResult =
|
||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
MPPObjectDetectorResult *expectedObjectDetectorResult =
|
||||
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
|
||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:-1
|
||||
equalsObjectDetectionResult:expectedObjectDetectionResult];
|
||||
equalsObjectDetectorResult:expectedObjectDetectorResult];
|
||||
}
|
||||
|
||||
- (void)testDetectWithCategoryAllowlistSucceeds {
|
||||
|
@ -359,13 +356,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
keypoints:nil],
|
||||
];
|
||||
|
||||
MPPObjectDetectionResult *expectedDetectionResult =
|
||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
MPPObjectDetectorResult *expectedDetectionResult =
|
||||
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
|
||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:-1
|
||||
equalsObjectDetectionResult:expectedDetectionResult];
|
||||
equalsObjectDetectorResult:expectedDetectionResult];
|
||||
}
|
||||
|
||||
- (void)testDetectWithCategoryDenylistSucceeds {
|
||||
|
@ -414,13 +411,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
keypoints:nil],
|
||||
];
|
||||
|
||||
MPPObjectDetectionResult *expectedDetectionResult =
|
||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
MPPObjectDetectorResult *expectedDetectionResult =
|
||||
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
|
||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:-1
|
||||
equalsObjectDetectionResult:expectedDetectionResult];
|
||||
equalsObjectDetectorResult:expectedDetectionResult];
|
||||
}
|
||||
|
||||
- (void)testDetectWithOrientationSucceeds {
|
||||
|
@ -437,8 +434,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
keypoints:nil],
|
||||
];
|
||||
|
||||
MPPObjectDetectionResult *expectedDetectionResult =
|
||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
MPPObjectDetectorResult *expectedDetectionResult =
|
||||
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||
|
||||
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage
|
||||
orientation:UIImageOrientationRight];
|
||||
|
@ -446,7 +443,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
[self assertResultsOfDetectInImage:image
|
||||
usingObjectDetector:objectDetector
|
||||
maxResults:1
|
||||
equalsObjectDetectionResult:expectedDetectionResult];
|
||||
equalsObjectDetectorResult:expectedDetectionResult];
|
||||
}
|
||||
|
||||
#pragma mark Running Mode Tests
|
||||
|
@ -613,15 +610,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage];
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image
|
||||
timestampInMilliseconds:i
|
||||
error:nil];
|
||||
MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInVideoFrame:image
|
||||
timestampInMilliseconds:i
|
||||
error:nil];
|
||||
|
||||
[self assertObjectDetectionResult:objectDetectionResult
|
||||
isEqualToExpectedResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
|
||||
expectedDetectionsCount:maxResults];
|
||||
[self assertObjectDetectorResult:ObjectDetectorResult
|
||||
isEqualToExpectedResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
|
||||
expectedDetectionsCount:maxResults];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -676,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
|
||||
// Because of flow limiting, we cannot ensure that the callback will be
|
||||
// invoked `iterationCount` times.
|
||||
// An normal expectation will fail if expectation.fullfill() is not called
|
||||
// An normal expectation will fail if expectation.fulfill() is not called
|
||||
// `expectation.expectedFulfillmentCount` times.
|
||||
// If `expectation.isInverted = true`, the test will only succeed if
|
||||
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
||||
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||
// Since in our case we cannot predict how many times the expectation is
|
||||
// supposed to be fullfilled setting,
|
||||
// supposed to be fulfilled setting,
|
||||
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||
// `expectation.isInverted = true` ensures that test succeeds if
|
||||
// expectation is fullfilled <= `iterationCount` times.
|
||||
// expectation is fulfilled <= `iterationCount` times.
|
||||
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
||||
expectation.expectedFulfillmentCount = iterationCount + 1;
|
||||
|
@ -714,16 +711,16 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
|
||||
#pragma mark MPPObjectDetectorLiveStreamDelegate Methods
|
||||
- (void)objectDetector:(MPPObjectDetector *)objectDetector
|
||||
didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult
|
||||
didFinishDetectionWithResult:(MPPObjectDetectorResult *)ObjectDetectorResult
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError *)error {
|
||||
NSInteger maxResults = 4;
|
||||
[self assertObjectDetectionResult:objectDetectionResult
|
||||
isEqualToExpectedResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
timestampInMilliseconds]
|
||||
expectedDetectionsCount:maxResults];
|
||||
[self assertObjectDetectorResult:ObjectDetectorResult
|
||||
isEqualToExpectedResult:
|
||||
[MPPObjectDetectorTests
|
||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||
timestampInMilliseconds]
|
||||
expectedDetectionsCount:maxResults];
|
||||
|
||||
if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) {
|
||||
[outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
|
||||
|
|
|
@ -64,4 +64,3 @@ objc_library(
|
|||
"//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ NS_SWIFT_NAME(GestureRecognizerOptions)
|
|||
gestureRecognizerLiveStreamDelegate;
|
||||
|
||||
/** Sets the maximum number of hands can be detected by the GestureRecognizer. */
|
||||
@property(nonatomic) NSInteger numberOfHands;
|
||||
@property(nonatomic) NSInteger numberOfHands NS_SWIFT_NAME(numHands);
|
||||
|
||||
/** Sets minimum confidence score for the hand detection to be considered successful */
|
||||
@property(nonatomic) float minHandDetectionConfidence;
|
||||
|
|
|
@ -31,7 +31,8 @@
|
|||
MPPGestureRecognizerOptions *gestureRecognizerOptions = [super copyWithZone:zone];
|
||||
|
||||
gestureRecognizerOptions.runningMode = self.runningMode;
|
||||
gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate = self.gestureRecognizerLiveStreamDelegate;
|
||||
gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate =
|
||||
self.gestureRecognizerLiveStreamDelegate;
|
||||
gestureRecognizerOptions.numberOfHands = self.numberOfHands;
|
||||
gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence;
|
||||
gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence;
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
|
||||
- (instancetype)initWithGestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
|
||||
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
|
||||
landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
|
||||
worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
|
||||
worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||
if (self) {
|
||||
_landmarks = landmarks;
|
||||
|
|
|
@ -22,7 +22,12 @@ objc_library(
|
|||
hdrs = ["sources/MPPGestureRecognizerOptions+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
"//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",
|
||||
|
|
|
@ -18,7 +18,12 @@
|
|||
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
|
||||
namespace {
|
||||
using CalculatorOptionsProto = mediapipe::CalculatorOptions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
static const int kMicroSecondsPerMilliSecond = 1000;
|
||||
static const int kMicrosecondsPerMillisecond = 1000;
|
||||
|
||||
namespace {
|
||||
using ClassificationResultProto =
|
||||
|
@ -29,19 +29,26 @@ using ::mediapipe::Packet;
|
|||
|
||||
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
|
||||
(const Packet &)packet {
|
||||
MPPClassificationResult *classificationResult;
|
||||
// Even if packet does not validate as the expected type, you can safely access the timestamp.
|
||||
NSInteger timestampInMilliSeconds =
|
||||
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
|
||||
|
||||
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
|
||||
return nil;
|
||||
// MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s
|
||||
// timestamp_ms(). It is 0 since the packet can't be validated as a `ClassificationResultProto`.
|
||||
return [[MPPImageClassifierResult alloc]
|
||||
initWithClassificationResult:[[MPPClassificationResult alloc] initWithClassifications:@[]
|
||||
timestampInMilliseconds:0]
|
||||
timestampInMilliseconds:timestampInMilliSeconds];
|
||||
}
|
||||
|
||||
classificationResult = [MPPClassificationResult
|
||||
MPPClassificationResult *classificationResult = [MPPClassificationResult
|
||||
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
|
||||
|
||||
return [[MPPImageClassifierResult alloc]
|
||||
initWithClassificationResult:classificationResult
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
kMicrosecondsPerMillisecond)];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -17,9 +17,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
objc_library(
|
||||
name = "MPPObjectDetectionResult",
|
||||
srcs = ["sources/MPPObjectDetectionResult.m"],
|
||||
hdrs = ["sources/MPPObjectDetectionResult.h"],
|
||||
name = "MPPObjectDetectorResult",
|
||||
srcs = ["sources/MPPObjectDetectorResult.m"],
|
||||
hdrs = ["sources/MPPObjectDetectorResult.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/components/containers:MPPDetection",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||
|
@ -31,7 +31,7 @@ objc_library(
|
|||
srcs = ["sources/MPPObjectDetectorOptions.m"],
|
||||
hdrs = ["sources/MPPObjectDetectorOptions.h"],
|
||||
deps = [
|
||||
":MPPObjectDetectionResult",
|
||||
":MPPObjectDetectorResult",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||
"//mediapipe/tasks/ios/vision/core:MPPRunningMode",
|
||||
],
|
||||
|
@ -47,8 +47,8 @@ objc_library(
|
|||
"-x objective-c++",
|
||||
],
|
||||
deps = [
|
||||
":MPPObjectDetectionResult",
|
||||
":MPPObjectDetectorOptions",
|
||||
":MPPObjectDetectorResult",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
|
@ -56,7 +56,7 @@ objc_library(
|
|||
"//mediapipe/tasks/ios/vision/core:MPPImage",
|
||||
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
|
||||
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
|
||||
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectionResultHelpers",
|
||||
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers",
|
||||
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorResultHelpers",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -109,14 +109,13 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
* @param error An optional error parameter populated when there is an error in performing object
|
||||
* detection on the input image.
|
||||
*
|
||||
* @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
|
||||
* @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
|
||||
* has a bounding box that is expressed in the unrotated input frame of reference coordinates
|
||||
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
|
||||
* image data.
|
||||
*/
|
||||
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detect(image:));
|
||||
- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
|
||||
error:(NSError **)error NS_SWIFT_NAME(detect(image:));
|
||||
|
||||
/**
|
||||
* Performs object detection on the provided video frame of type `MPPImage` using the whole
|
||||
|
@ -139,14 +138,14 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
* @param error An optional error parameter populated when there is an error in performing object
|
||||
* detection on the input image.
|
||||
*
|
||||
* @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
|
||||
* @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
|
||||
* has a bounding box that is expressed in the unrotated input frame of reference coordinates
|
||||
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
|
||||
* image data.
|
||||
*/
|
||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error
|
||||
- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
|
||||
|
||||
/**
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
@ -118,9 +118,9 @@ static NSString *const kTaskName = @"objectDetector";
|
|||
return;
|
||||
}
|
||||
|
||||
MPPObjectDetectionResult *result = [MPPObjectDetectionResult
|
||||
objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName
|
||||
.cppString]];
|
||||
MPPObjectDetectorResult *result = [MPPObjectDetectorResult
|
||||
objectDetectorResultWithDetectionsPacket:statusOrPackets
|
||||
.value()[kDetectionsStreamName.cppString]];
|
||||
|
||||
NSInteger timeStampInMilliseconds =
|
||||
outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
|
||||
|
@ -184,9 +184,9 @@ static NSString *const kTaskName = @"objectDetector";
|
|||
return inputPacketMap;
|
||||
}
|
||||
|
||||
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<NormalizedRect> rect =
|
||||
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi
|
||||
imageSize:CGSizeMake(image.width, image.height)
|
||||
|
@ -213,18 +213,18 @@ static NSString *const kTaskName = @"objectDetector";
|
|||
return nil;
|
||||
}
|
||||
|
||||
return [MPPObjectDetectionResult
|
||||
objectDetectionResultWithDetectionsPacket:outputPacketMap
|
||||
.value()[kDetectionsStreamName.cppString]];
|
||||
return [MPPObjectDetectorResult
|
||||
objectDetectorResultWithDetectionsPacket:outputPacketMap
|
||||
.value()[kDetectionsStreamName.cppString]];
|
||||
}
|
||||
|
||||
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
|
||||
- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
|
||||
return [self detectInImage:image regionOfInterest:CGRectZero error:error];
|
||||
}
|
||||
|
||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
error:error];
|
||||
|
@ -239,9 +239,9 @@ static NSString *const kTaskName = @"objectDetector";
|
|||
return nil;
|
||||
}
|
||||
|
||||
return [MPPObjectDetectionResult
|
||||
objectDetectionResultWithDetectionsPacket:outputPacketMap
|
||||
.value()[kDetectionsStreamName.cppString]];
|
||||
return [MPPObjectDetectorResult
|
||||
objectDetectorResultWithDetectionsPacket:outputPacketMap
|
||||
.value()[kDetectionsStreamName.cppString]];
|
||||
}
|
||||
|
||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -44,7 +44,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
|
|||
*
|
||||
* @param objectDetector The object detector which performed the object detection.
|
||||
* This is useful to test equality when there are multiple instances of `MPPObjectDetector`.
|
||||
* @param result The `MPPObjectDetectionResult` object that contains a list of detections, each
|
||||
* @param result The `MPPObjectDetectorResult` object that contains a list of detections, each
|
||||
* detection has a bounding box that is expressed in the unrotated input frame of reference
|
||||
* coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
|
||||
* underlying image data.
|
||||
|
@ -54,7 +54,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
|
|||
* detection on the input live stream image data.
|
||||
*/
|
||||
- (void)objectDetector:(MPPObjectDetector *)objectDetector
|
||||
didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result
|
||||
didFinishDetectionWithResult:(nullable MPPObjectDetectorResult *)result
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(nullable NSError *)error
|
||||
NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:));
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Represents the detection results generated by `MPPObjectDetector`. */
|
||||
NS_SWIFT_NAME(ObjectDetectionResult)
|
||||
@interface MPPObjectDetectionResult : MPPTaskResult
|
||||
NS_SWIFT_NAME(ObjectDetectorResult)
|
||||
@interface MPPObjectDetectorResult : MPPTaskResult
|
||||
|
||||
/**
|
||||
* The array of `MPPDetection` objects each of which has a bounding box that is expressed in the
|
||||
|
@ -30,7 +30,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
|
|||
@property(nonatomic, readonly) NSArray<MPPDetection *> *detections;
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in
|
||||
* Initializes a new `MPPObjectDetectorResult` with the given array of detections and timestamp (in
|
||||
* milliseconds).
|
||||
*
|
||||
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
|
||||
|
@ -38,7 +38,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
|
|||
* x [0,image_height)`, which are the dimensions of the underlying image data.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||
*
|
||||
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
|
||||
* @return An instance of `MPPObjectDetectorResult` initialized with the given array of detections
|
||||
* and timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
|
@ -12,9 +12,9 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
|
||||
|
||||
@implementation MPPObjectDetectionResult
|
||||
@implementation MPPObjectDetectorResult
|
||||
|
||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
|
@ -31,12 +31,12 @@ objc_library(
|
|||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPObjectDetectionResultHelpers",
|
||||
srcs = ["sources/MPPObjectDetectionResult+Helpers.mm"],
|
||||
hdrs = ["sources/MPPObjectDetectionResult+Helpers.h"],
|
||||
name = "MPPObjectDetectorResultHelpers",
|
||||
srcs = ["sources/MPPObjectDetectorResult+Helpers.mm"],
|
||||
hdrs = ["sources/MPPObjectDetectorResult+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
|
||||
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectionResult",
|
||||
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectorResult",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
|
||||
|
||||
#include "mediapipe/framework/packet.h"
|
||||
|
||||
|
@ -20,17 +20,17 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
|
||||
static const int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
@interface MPPObjectDetectionResult (Helpers)
|
||||
@interface MPPObjectDetectorResult (Helpers)
|
||||
|
||||
/**
|
||||
* Creates an `MPPObjectDetectionResult` from a MediaPipe packet containing a
|
||||
* Creates an `MPPObjectDetectorResult` from a MediaPipe packet containing a
|
||||
* `std::vector<DetectionProto>`.
|
||||
*
|
||||
* @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`.
|
||||
*
|
||||
* @return An `MPPObjectDetectionResult` object that contains a list of detections.
|
||||
* @return An `MPPObjectDetectorResult` object that contains a list of detections.
|
||||
*/
|
||||
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
|
||||
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
|
||||
(const mediapipe::Packet &)packet;
|
||||
|
||||
@end
|
|
@ -12,7 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
|
||||
|
||||
|
@ -21,9 +21,9 @@ using DetectionProto = ::mediapipe::Detection;
|
|||
using ::mediapipe::Packet;
|
||||
} // namespace
|
||||
|
||||
@implementation MPPObjectDetectionResult (Helpers)
|
||||
@implementation MPPObjectDetectorResult (Helpers)
|
||||
|
||||
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
|
||||
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
|
||||
(const Packet &)packet {
|
||||
if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
|
||||
return nil;
|
||||
|
@ -37,10 +37,10 @@ using ::mediapipe::Packet;
|
|||
[detections addObject:[MPPDetection detectionWithProto:detectionProto]];
|
||||
}
|
||||
|
||||
return [[MPPObjectDetectionResult alloc]
|
||||
initWithDetections:detections
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
return
|
||||
[[MPPObjectDetectorResult alloc] initWithDetections:detections
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
}
|
||||
|
||||
@end
|
|
@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
|||
// For 90° and 270° rotations, we need to swap width and height.
|
||||
// This is due to the internal behavior of ImageToTensorCalculator, which:
|
||||
// - first denormalizes the provided rect by multiplying the rect width or
|
||||
// height by the image width or height, repectively.
|
||||
// height by the image width or height, respectively.
|
||||
// - then rotates this by denormalized rect by the provided rotation, and
|
||||
// uses this for cropping,
|
||||
// - then finally rotates this back.
|
||||
|
|
|
@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
segmenterOptions.outputCategoryMask()
|
||||
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
||||
: -1;
|
||||
final int qualityScoresOutStreamIndex =
|
||||
getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
|
||||
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
|
@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
new ArrayList<>(),
|
||||
packets.get(imageOutStreamIndex).getTimestamp());
|
||||
}
|
||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||
|
@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||
categoryMask = Optional.of(builder.build());
|
||||
}
|
||||
float[] qualityScores =
|
||||
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
|
||||
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
|
||||
for (float score : qualityScores) {
|
||||
qualityScoresList.add(score);
|
||||
}
|
||||
return ImageSegmenterResult.create(
|
||||
confidenceMasks,
|
||||
categoryMask,
|
||||
qualityScoresList,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
|
||||
}
|
||||
|
@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
public abstract Builder setOutputCategoryMask(boolean value);
|
||||
|
||||
/**
|
||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||
* pipeline is done processing an image.
|
||||
* /** Sets an optional {@link ResultListener} to receive the segmentation results when the
|
||||
* graph pipeline is done processing an image.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageSegmenterResult, MPImage> value);
|
||||
|
|
|
@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
|||
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||
* category mask, where each pixel represents the class which the pixel in the original image
|
||||
* was predicted to belong to.
|
||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
|
||||
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||
* the category in the model outputs.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
public static ImageSegmenterResult create(
|
||||
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
|
||||
Optional<List<MPImage>> confidenceMasks,
|
||||
Optional<MPImage> categoryMask,
|
||||
List<Float> qualityScores,
|
||||
long timestampMs) {
|
||||
return new AutoValue_ImageSegmenterResult(
|
||||
confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
|
||||
confidenceMasks.map(Collections::unmodifiableList),
|
||||
categoryMask,
|
||||
Collections.unmodifiableList(qualityScores),
|
||||
timestampMs);
|
||||
}
|
||||
|
||||
public abstract Optional<List<MPImage>> confidenceMasks();
|
||||
|
||||
public abstract Optional<MPImage> categoryMask();
|
||||
|
||||
public abstract List<Float> qualityScores();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
||||
|
|
|
@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||
}
|
||||
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
outputStreams.add("QUALITY_SCORES:quality_scores");
|
||||
final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
// TODO: add test for stream indices.
|
||||
final int imageOutStreamIndex = outputStreams.size() - 1;
|
||||
|
@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
new ArrayList<>(),
|
||||
packets.get(imageOutStreamIndex).getTimestamp());
|
||||
}
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||
|
@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
categoryMask = Optional.of(builder.build());
|
||||
}
|
||||
|
||||
float[] qualityScores =
|
||||
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
|
||||
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
|
||||
for (float score : qualityScores) {
|
||||
qualityScoresList.add(score);
|
||||
}
|
||||
|
||||
return ImageSegmenterResult.create(
|
||||
confidenceMasks,
|
||||
categoryMask,
|
||||
qualityScoresList,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
|
||||
}
|
||||
|
|
|
@ -201,6 +201,7 @@ py_test(
|
|||
"//mediapipe/tasks/testdata/vision:test_images",
|
||||
"//mediapipe/tasks/testdata/vision:test_models",
|
||||
],
|
||||
tags = ["not_run:arm"],
|
||||
deps = [
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
|
|
|
@ -27,13 +27,14 @@ export declare interface Detection {
|
|||
boundingBox?: BoundingBox;
|
||||
|
||||
/**
|
||||
* Optional list of keypoints associated with the detection. Keypoints
|
||||
* represent interesting points related to the detection. For example, the
|
||||
* keypoints represent the eye, ear and mouth from face detection model. Or
|
||||
* in the template matching detection, e.g. KNIFT, they can represent the
|
||||
* feature points for template matching.
|
||||
* List of keypoints associated with the detection. Keypoints represent
|
||||
* interesting points related to the detection. For example, the keypoints
|
||||
* represent the eye, ear and mouth from face detection model. Or in the
|
||||
* template matching detection, e.g. KNIFT, they can represent the feature
|
||||
* points for template matching. Contains an empty list if no keypoints are
|
||||
* detected.
|
||||
*/
|
||||
keypoints?: NormalizedKeypoint[];
|
||||
keypoints: NormalizedKeypoint[];
|
||||
}
|
||||
|
||||
/** Detection results of a model. */
|
||||
|
|
|
@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
|
|||
categoryName: '',
|
||||
displayName: '',
|
||||
}],
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
|
||||
keypoints: []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
|
|||
const labels = source.getLabelList();
|
||||
const displayNames = source.getDisplayNameList();
|
||||
|
||||
const detection: Detection = {categories: []};
|
||||
const detection: Detection = {categories: [], keypoints: []};
|
||||
for (let i = 0; i < scores.length; i++) {
|
||||
detection.categories.push({
|
||||
score: scores[i],
|
||||
|
@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
|
|||
}
|
||||
|
||||
if (source.getLocationData()?.getRelativeKeypointsList().length) {
|
||||
detection.keypoints = [];
|
||||
for (const keypoint of
|
||||
source.getLocationData()!.getRelativeKeypointsList()) {
|
||||
detection.keypoints.push({
|
||||
|
|
|
@ -62,7 +62,10 @@ jasmine_node_test(
|
|||
mediapipe_ts_library(
|
||||
name = "mask",
|
||||
srcs = ["mask.ts"],
|
||||
deps = [":image"],
|
||||
deps = [
|
||||
":image",
|
||||
"//mediapipe/web/graph_runner:platform_utils",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
|
|
|
@ -60,6 +60,10 @@ class MPImageTestContext {
|
|||
|
||||
this.webGLTexture = gl.createTexture()!;
|
||||
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
|
|
|
@ -187,10 +187,11 @@ export class MPImage {
|
|||
destinationContainer =
|
||||
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
|
||||
|
||||
this.configureTextureParams();
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA,
|
||||
gl.UNSIGNED_BYTE, null);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
|
||||
shaderContext.bindFramebuffer(gl, destinationContainer);
|
||||
shaderContext.run(gl, /* flipVertically= */ false, () => {
|
||||
|
@ -302,6 +303,20 @@ export class MPImage {
|
|||
return webGLTexture;
|
||||
}
|
||||
|
||||
/** Sets texture params for the currently bound texture. */
|
||||
private configureTextureParams() {
|
||||
const gl = this.getGL();
|
||||
// `gl.LINEAR` might break rendering for some textures, but it allows us to
|
||||
// do smooth resizing. Ideally, this would be user-configurable, but for now
|
||||
// we hard-code the value here to `gl.LINEAR` (versus `gl.NEAREST` for
|
||||
// `MPMask` where we do not want to interpolate mask values, especially for
|
||||
// category masks).
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds the backing texture to the canvas. If the texture does not yet
|
||||
* exist, creates it first.
|
||||
|
@ -318,16 +333,12 @@ export class MPImage {
|
|||
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||
this.containers.push(webGLTexture);
|
||||
this.ownsWebGLTexture = true;
|
||||
|
||||
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
|
||||
this.configureTextureParams();
|
||||
} else {
|
||||
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
|
||||
}
|
||||
|
||||
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
|
||||
// TODO: Ideally, we would only set these once per texture and
|
||||
// not once every frame.
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
|
||||
|
||||
return webGLTexture;
|
||||
}
|
||||
|
||||
|
|
|
@ -60,8 +60,11 @@ class MPMaskTestContext {
|
|||
}
|
||||
|
||||
this.webGLTexture = gl.createTexture()!;
|
||||
|
||||
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT,
|
||||
new Float32Array(pixels).map(v => v / 255));
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||
import {isIOS} from '../../../../web/graph_runner/platform_utils';
|
||||
|
||||
/** Number of instances a user can keep alive before we raise a warning. */
|
||||
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
|
||||
|
@ -32,6 +33,8 @@ enum MPMaskType {
|
|||
/** The supported mask formats. For internal usage. */
|
||||
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* The wrapper class for MediaPipe segmentation masks.
|
||||
*
|
||||
|
@ -56,6 +59,9 @@ export class MPMask {
|
|||
*/
|
||||
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
|
||||
|
||||
/** The format used to write pixel values from textures. */
|
||||
private static texImage2DFormat?: GLenum;
|
||||
|
||||
/** @hideconstructor */
|
||||
constructor(
|
||||
private readonly containers: MPMaskContainer[],
|
||||
|
@ -127,6 +133,29 @@ export class MPMask {
|
|||
return this.convertToWebGLTexture();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the texture format used for writing float textures on this
|
||||
* platform.
|
||||
*/
|
||||
getTexImage2DFormat(): GLenum {
|
||||
const gl = this.getGL();
|
||||
if (!MPMask.texImage2DFormat) {
|
||||
// Note: This is the same check we use in
|
||||
// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
|
||||
if (gl.getExtension('EXT_color_buffer_float') &&
|
||||
gl.getExtension('OES_texture_float_linear') &&
|
||||
gl.getExtension('EXT_float_blend')) {
|
||||
MPMask.texImage2DFormat = gl.R32F;
|
||||
} else if (gl.getExtension('EXT_color_buffer_half_float')) {
|
||||
MPMask.texImage2DFormat = gl.R16F;
|
||||
} else {
|
||||
throw new Error(
|
||||
'GPU does not fully support 4-channel float32 or float16 formats');
|
||||
}
|
||||
}
|
||||
return MPMask.texImage2DFormat;
|
||||
}
|
||||
|
||||
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
|
||||
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
|
||||
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
|
||||
|
@ -175,8 +204,10 @@ export class MPMask {
|
|||
destinationContainer =
|
||||
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
|
||||
this.configureTextureParams();
|
||||
const format = this.getTexImage2DFormat();
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
|
||||
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
|
||||
gl.FLOAT, null);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
|
||||
|
@ -207,7 +238,7 @@ export class MPMask {
|
|||
if (!this.canvas) {
|
||||
throw new Error(
|
||||
'Conversion to different image formats require that a canvas ' +
|
||||
'is passed when iniitializing the image.');
|
||||
'is passed when initializing the image.');
|
||||
}
|
||||
if (!this.gl) {
|
||||
this.gl = assertNotNull(
|
||||
|
@ -215,11 +246,6 @@ export class MPMask {
|
|||
'You cannot use a canvas that is already bound to a different ' +
|
||||
'type of rendering context.');
|
||||
}
|
||||
const ext = this.gl.getExtension('EXT_color_buffer_float');
|
||||
if (!ext) {
|
||||
// TODO: Ensure this works on iOS
|
||||
throw new Error('Missing required EXT_color_buffer_float extension');
|
||||
}
|
||||
return this.gl;
|
||||
}
|
||||
|
||||
|
@ -237,18 +263,34 @@ export class MPMask {
|
|||
if (uint8Array) {
|
||||
float32Array = new Float32Array(uint8Array).map(v => v / 255);
|
||||
} else {
|
||||
float32Array = new Float32Array(this.width * this.height);
|
||||
|
||||
const gl = this.getGL();
|
||||
const shaderContext = this.getShaderContext();
|
||||
float32Array = new Float32Array(this.width * this.height);
|
||||
|
||||
// Create texture if needed
|
||||
const webGlTexture = this.convertToWebGLTexture();
|
||||
|
||||
// Create a framebuffer from the texture and read back pixels
|
||||
shaderContext.bindFramebuffer(gl, webGlTexture);
|
||||
gl.readPixels(
|
||||
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
|
||||
shaderContext.unbindFramebuffer();
|
||||
|
||||
if (isIOS()) {
|
||||
// WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
|
||||
// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
|
||||
// Uint16Array, however, and requires a manual bitwise conversion from
|
||||
// Uint16 to floating point numbers. This conversion is more expensive
|
||||
// that reading back a Float32Array from the RGBA image and dropping
|
||||
// the superfluous data, so we do this instead.
|
||||
const outputArray = new Float32Array(this.width * this.height * 4);
|
||||
gl.readPixels(
|
||||
0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
|
||||
for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
|
||||
float32Array[i] = outputArray[j];
|
||||
}
|
||||
} else {
|
||||
gl.readPixels(
|
||||
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
|
||||
}
|
||||
}
|
||||
this.containers.push(float32Array);
|
||||
}
|
||||
|
@ -273,9 +315,9 @@ export class MPMask {
|
|||
webGLTexture = this.bindTexture();
|
||||
|
||||
const data = this.convertToFloat32Array();
|
||||
// TODO: Add support for R16F to support iOS
|
||||
const format = this.getTexImage2DFormat();
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
|
||||
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
|
||||
gl.FLOAT, data);
|
||||
this.unbindTexture();
|
||||
}
|
||||
|
@ -283,6 +325,19 @@ export class MPMask {
|
|||
return webGLTexture;
|
||||
}
|
||||
|
||||
/** Sets texture params for the currently bound texture. */
|
||||
private configureTextureParams() {
|
||||
const gl = this.getGL();
|
||||
// `gl.NEAREST` ensures that we do not get interpolated values for
|
||||
// masks. In some cases, the user might want interpolation (e.g. for
|
||||
// confidence masks), so we might want to make this user-configurable.
|
||||
// Note that `MPImage` uses `gl.LINEAR`.
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds the backing texture to the canvas. If the texture does not yet
|
||||
* exist, creates it first.
|
||||
|
@ -299,15 +354,12 @@ export class MPMask {
|
|||
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||
this.containers.push(webGLTexture);
|
||||
this.ownsWebGLTexture = true;
|
||||
}
|
||||
|
||||
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
|
||||
// TODO: Ideally, we would only set these once per texture and
|
||||
// not once every frame.
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
|
||||
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
|
||||
this.configureTextureParams();
|
||||
} else {
|
||||
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
|
||||
}
|
||||
|
||||
return webGLTexture;
|
||||
}
|
||||
|
|
|
@ -191,7 +191,8 @@ describe('FaceDetector', () => {
|
|||
categoryName: '',
|
||||
displayName: '',
|
||||
}],
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
|
||||
keypoints: []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs face stylization on the provided single image and returns the
|
||||
* result. This method creates a copy of the resulting image and should not be
|
||||
* used in high-throughput applictions. Only use this method when the
|
||||
* used in high-throughput applications. Only use this method when the
|
||||
* FaceStylizer is created with the image running mode.
|
||||
*
|
||||
* @param image An image to process.
|
||||
|
@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs face stylization on the provided single image and returns the
|
||||
* result. This method creates a copy of the resulting image and should not be
|
||||
* used in high-throughput applictions. Only use this method when the
|
||||
* used in high-throughput applications. Only use this method when the
|
||||
* FaceStylizer is created with the image running mode.
|
||||
*
|
||||
* The 'imageProcessingOptions' parameter can be used to specify one or all
|
||||
|
@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs face stylization on the provided video frame. This method creates
|
||||
* a copy of the resulting image and should not be used in high-throughput
|
||||
* applictions. Only use this method when the FaceStylizer is created with the
|
||||
* applications. Only use this method when the FaceStylizer is created with the
|
||||
* video running mode.
|
||||
*
|
||||
* The input frame can be of any size. It's required to provide the video
|
||||
|
|
|
@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
|
|||
const NORM_RECT_STREAM = 'norm_rect';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||
const IMAGE_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
|||
export class ImageSegmenter extends VisionTaskRunner {
|
||||
private categoryMask?: MPMask;
|
||||
private confidenceMasks?: MPMask[];
|
||||
private qualityScores?: number[];
|
||||
private labels: string[] = [];
|
||||
private userCallback?: ImageSegmenterCallback;
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
|
@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs image segmentation on the provided single image and returns the
|
||||
* segmentation result. This method creates a copy of the resulting masks and
|
||||
* should not be used in high-throughput applictions. Only use this method
|
||||
* should not be used in high-throughput applications. Only use this method
|
||||
* when the ImageSegmenter is created with running mode `image`.
|
||||
*
|
||||
* @param image An image to process.
|
||||
|
@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs image segmentation on the provided single image and returns the
|
||||
* segmentation result. This method creates a copy of the resulting masks and
|
||||
* should not be used in high-v applictions. Only use this method when
|
||||
* should not be used in high-v applications. Only use this method when
|
||||
* the ImageSegmenter is created with running mode `image`.
|
||||
*
|
||||
* @param image An image to process.
|
||||
|
@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs image segmentation on the provided video frame and returns the
|
||||
* segmentation result. This method creates a copy of the resulting masks and
|
||||
* should not be used in high-v applictions. Only use this method when
|
||||
* should not be used in high-v applications. Only use this method when
|
||||
* the ImageSegmenter is created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
|
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
private reset(): void {
|
||||
this.categoryMask = undefined;
|
||||
this.confidenceMasks = undefined;
|
||||
this.qualityScores = undefined;
|
||||
}
|
||||
|
||||
private processResults(): ImageSegmenterResult|void {
|
||||
try {
|
||||
const result =
|
||||
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
|
||||
const result = new ImageSegmenterResult(
|
||||
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||
if (this.userCallback) {
|
||||
this.userCallback(result);
|
||||
} else {
|
||||
|
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
});
|
||||
}
|
||||
|
||||
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
|
||||
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
|
||||
|
||||
this.graphRunner.attachFloatVectorListener(
|
||||
QUALITY_SCORES_STREAM, (scores, timestamp) => {
|
||||
this.qualityScores = scores;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
QUALITY_SCORES_STREAM, timestamp => {
|
||||
this.categoryMask = undefined;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,13 @@ export class ImageSegmenterResult {
|
|||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||
* which the pixel in the original image was predicted to belong to.
|
||||
*/
|
||||
readonly categoryMask?: MPMask) {}
|
||||
readonly categoryMask?: MPMask,
|
||||
/**
|
||||
* The quality scores of the result masks, in the range of [0, 1].
|
||||
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||
* element corresponds to the score of the category in the model outputs.
|
||||
*/
|
||||
readonly qualityScores?: number[]) {}
|
||||
|
||||
/** Frees the resources held by the category and confidence masks. */
|
||||
close(): void {
|
||||
|
|
|
@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
qualityScoresListener:
|
||||
((data: number[], timestamp: number) => void)|undefined;
|
||||
|
||||
constructor() {
|
||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||
|
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[2] =
|
||||
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('quality_scores');
|
||||
this.qualityScoresListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
|
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
|
|||
it('invokes listener after masks are available', async () => {
|
||||
const categoryMask = new Uint8Array([1]);
|
||||
const confidenceMask = new Float32Array([0.0]);
|
||||
const qualityScores = [1.0];
|
||||
let listenerCalled = false;
|
||||
|
||||
await imageSegmenter.setOptions(
|
||||
|
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
|
|||
],
|
||||
1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, () => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||
listenerCalled = true;
|
||||
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||
expect(result.qualityScores).toEqual(qualityScores);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
|||
const ROI_IN_STREAM = 'roi_in';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||
|
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
|
|||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private categoryMask?: MPMask;
|
||||
private confidenceMasks?: MPMask[];
|
||||
private qualityScores?: number[];
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
private userCallback?: InteractiveSegmenterCallback;
|
||||
|
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
private reset(): void {
|
||||
this.confidenceMasks = undefined;
|
||||
this.categoryMask = undefined;
|
||||
this.qualityScores = undefined;
|
||||
}
|
||||
|
||||
private processResults(): InteractiveSegmenterResult|void {
|
||||
try {
|
||||
const result = new InteractiveSegmenterResult(
|
||||
this.confidenceMasks, this.categoryMask);
|
||||
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||
if (this.userCallback) {
|
||||
this.userCallback(result);
|
||||
} else {
|
||||
|
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
});
|
||||
}
|
||||
|
||||
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
|
||||
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
|
||||
|
||||
this.graphRunner.attachFloatVectorListener(
|
||||
QUALITY_SCORES_STREAM, (scores, timestamp) => {
|
||||
this.qualityScores = scores;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
QUALITY_SCORES_STREAM, timestamp => {
|
||||
this.categoryMask = undefined;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
|
|||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||
* which the pixel in the original image was predicted to belong to.
|
||||
*/
|
||||
readonly categoryMask?: MPMask) {}
|
||||
readonly categoryMask?: MPMask,
|
||||
/**
|
||||
* The quality scores of the result masks, in the range of [0, 1].
|
||||
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||
* element corresponds to the score of the category in the model outputs.
|
||||
*/
|
||||
readonly qualityScores?: number[]) {}
|
||||
|
||||
/** Frees the resources held by the category and confidence masks. */
|
||||
close(): void {
|
||||
|
|
|
@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
qualityScoresListener:
|
||||
((data: number[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
constructor() {
|
||||
|
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[2] =
|
||||
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('quality_scores');
|
||||
this.qualityScoresListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
|
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('invokes listener after masks are avaiblae', async () => {
|
||||
it('invokes listener after masks are available', async () => {
|
||||
const categoryMask = new Uint8Array([1]);
|
||||
const confidenceMask = new Float32Array([0.0]);
|
||||
const qualityScores = [1.0];
|
||||
let listenerCalled = false;
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
|
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
|
|||
],
|
||||
1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
|
||||
listenerCalled = true;
|
||||
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||
expect(result.qualityScores).toEqual(qualityScores);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
|
|||
categoryName: '',
|
||||
displayName: '',
|
||||
}],
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
|
||||
keypoints: []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
|
|||
// it uses "CriOS".
|
||||
return userAgent.includes('Safari') && !userAgent.includes('Chrome');
|
||||
}
|
||||
|
||||
/** Detect if code is running on iOS. */
|
||||
export function isIOS() {
|
||||
// Source:
|
||||
// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
|
||||
return [
|
||||
'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
|
||||
'iPod'
|
||||
// tslint:disable-next-line:deprecation
|
||||
].includes(navigator.platform)
|
||||
// iPad on iOS 13 detection
|
||||
|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
|
||||
}
|
||||
|
|
5
setup.py
5
setup.py
|
@ -357,7 +357,10 @@ class BuildExtension(build_ext.build_ext):
|
|||
for ext in self.extensions:
|
||||
target_name = self.get_ext_fullpath(ext.name)
|
||||
# Build x86
|
||||
self._build_binary(ext)
|
||||
self._build_binary(
|
||||
ext,
|
||||
['--cpu=darwin', '--ios_multi_cpus=i386,x86_64,armv7,arm64'],
|
||||
)
|
||||
x86_name = self.get_ext_fullpath(ext.name)
|
||||
# Build Arm64
|
||||
ext.name = ext.name + '.arm64'
|
||||
|
|
3
third_party/flatbuffers/BUILD.bazel
vendored
3
third_party/flatbuffers/BUILD.bazel
vendored
|
@ -42,16 +42,15 @@ filegroup(
|
|||
"include/flatbuffers/allocator.h",
|
||||
"include/flatbuffers/array.h",
|
||||
"include/flatbuffers/base.h",
|
||||
"include/flatbuffers/bfbs_generator.h",
|
||||
"include/flatbuffers/buffer.h",
|
||||
"include/flatbuffers/buffer_ref.h",
|
||||
"include/flatbuffers/code_generator.h",
|
||||
"include/flatbuffers/code_generators.h",
|
||||
"include/flatbuffers/default_allocator.h",
|
||||
"include/flatbuffers/detached_buffer.h",
|
||||
"include/flatbuffers/file_manager.h",
|
||||
"include/flatbuffers/flatbuffer_builder.h",
|
||||
"include/flatbuffers/flatbuffers.h",
|
||||
"include/flatbuffers/flatc.h",
|
||||
"include/flatbuffers/flex_flat_util.h",
|
||||
"include/flatbuffers/flexbuffers.h",
|
||||
"include/flatbuffers/grpc.h",
|
||||
|
|
8
third_party/flatbuffers/workspace.bzl
vendored
8
third_party/flatbuffers/workspace.bzl
vendored
|
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
|
|||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "flatbuffers",
|
||||
strip_prefix = "flatbuffers-23.1.21",
|
||||
sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
|
||||
strip_prefix = "flatbuffers-23.5.8",
|
||||
sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
|
||||
"https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
|
||||
"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
|
||||
"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
|
||||
],
|
||||
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
||||
delete = ["build_defs.bzl", "BUILD.bazel"],
|
||||
|
|
Loading…
Reference in New Issue
Block a user