Pulled changes from master

commit 164eae8c16
@@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
 typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
 
+typedef ConcatenateVectorCalculator<std::string>
+    ConcatenateStringVectorCalculator;
+MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator);
+
 // Example config:
 // node {
 //   calculator: "ConcatenateTfLiteTensorVectorCalculator"

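For reference, registering a concatenation calculator for some other element type follows the same pattern the hunk above adds for std::string. A minimal sketch; the element type and its header are hypothetical, only the typedef/registration pattern is taken from the diff:

    #include "mediapipe/calculators/core/concatenate_vector_calculator.h"

    namespace mediapipe {
    // Placeholder payload type, not part of this commit.
    struct MyStruct { int id = 0; };
    typedef ConcatenateVectorCalculator<MyStruct> ConcatenateMyStructVectorCalculator;
    MEDIAPIPE_REGISTER_NODE(ConcatenateMyStructVectorCalculator);
    }  // namespace mediapipe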
@@ -30,13 +30,15 @@ namespace mediapipe {
 typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
 
-void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp,
+template <typename T>
+void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp,
                     CalculatorRunner* runner) {
   runner->MutableInputs()->Index(index).packets.push_back(
-      MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
+      MakePacket<std::vector<T>>(input).At(Timestamp(timestamp)));
 }
 
-void AddInputVectors(const std::vector<std::vector<int>>& inputs,
+template <typename T>
+void AddInputVectors(const std::vector<std::vector<T>>& inputs,
                      int64_t timestamp, CalculatorRunner* runner) {
   for (int i = 0; i < inputs.size(); ++i) {
     AddInputVector(i, inputs[i], timestamp, runner);
@@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
   EXPECT_EQ(0, outputs.size());
 }
 
+TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) {
+  CalculatorRunner runner("ConcatenateStringVectorCalculator",
+                          /*options_string=*/"", /*num_inputs=*/3,
+                          /*num_outputs=*/1, /*num_side_packets=*/0);
+
+  std::vector<std::vector<std::string>> inputs = {
+      {"a", "b"}, {"c"}, {"d", "e", "f"}};
+  AddInputVectors(inputs, /*timestamp=*/1, &runner);
+  MP_ASSERT_OK(runner.Run());
+
+  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
+  std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
+  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
+}
+
 typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
     TestConcatenateUniqueIntPtrCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);

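Because `AddInputVector`/`AddInputVectors` are now function templates, the element type is deduced from the argument, so the same helpers serve both the existing int tests and the new string test. A minimal deduction sketch, assuming a `CalculatorRunner` named `runner` already exists:

    std::vector<std::vector<int>> int_inputs = {{1, 2}, {3}};
    std::vector<std::vector<std::string>> str_inputs = {{"a"}, {"b", "c"}};
    AddInputVectors(int_inputs, /*timestamp=*/1, &runner);  // T deduced as int
    AddInputVectors(str_inputs, /*timestamp=*/2, &runner);  // T deduced as std::string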
@@ -1099,6 +1099,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
     ],
+    alwayslink = True,  # Defines TestServiceCalculator
 )
 
 cc_library(

@@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
   return std::move(SetNoLogging());
 }
 
-StatusBuilder::operator Status() const& {
+StatusBuilder::operator absl::Status() const& {
   return StatusBuilder(*this).JoinMessageToStatus();
 }
 
-StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
+StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
 
 absl::Status StatusBuilder::JoinMessageToStatus() {
   if (!impl_) {

@@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
     return std::move(*this << msg);
   }
 
-  operator Status() const&;
-  operator Status() &&;
+  operator absl::Status() const&;
+  operator absl::Status() &&;
 
   absl::Status JoinMessageToStatus();
 

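The pair of conversion operators follows a common builder idiom: the lvalue overload converts a copy so the builder stays usable, while the rvalue overload consumes the expiring builder. A stripped-down illustration of the same idiom, independent of MediaPipe's StatusBuilder:

    #include <string>
    #include <utility>

    class MessageBuilder {
     public:
      explicit MessageBuilder(std::string base) : msg_(std::move(base)) {}
      MessageBuilder& operator<<(const std::string& extra) {
        msg_ += extra;
        return *this;
      }
      // Lvalue overload: hand out a copy, the builder remains valid.
      operator std::string() const& { return msg_; }
      // Rvalue overload: the builder is expiring, move the payload out.
      operator std::string() && { return std::move(msg_); }

     private:
      std::string msg_;
    };

    std::string MakeGreeting() {
      MessageBuilder b("hello");
      b << ", world";
      return std::move(b);  // binds the && overload, no extra copy
    }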
@@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
     lhs op## = rhs;                               \
     return lhs;                                   \
   }
-STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
+STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
 #undef STRONG_INT_VS_STRONG_INT_BINARY_OP
 
 // Define operators that take one StrongInt and one native integer argument.
@@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
     rhs op## = lhs;                               \
     return rhs;                                   \
   }
-STRONG_INT_VS_NUMERIC_BINARY_OP(*);
-NUMERIC_VS_STRONG_INT_BINARY_OP(*);
-STRONG_INT_VS_NUMERIC_BINARY_OP(/);
-STRONG_INT_VS_NUMERIC_BINARY_OP(%);
-STRONG_INT_VS_NUMERIC_BINARY_OP(<<);  // NOLINT(whitespace/operators)
-STRONG_INT_VS_NUMERIC_BINARY_OP(>>);  // NOLINT(whitespace/operators)
+STRONG_INT_VS_NUMERIC_BINARY_OP(*)
+NUMERIC_VS_STRONG_INT_BINARY_OP(*)
+STRONG_INT_VS_NUMERIC_BINARY_OP(/)
+STRONG_INT_VS_NUMERIC_BINARY_OP(%)
+STRONG_INT_VS_NUMERIC_BINARY_OP(<<)  // NOLINT(whitespace/operators)
+STRONG_INT_VS_NUMERIC_BINARY_OP(>>)  // NOLINT(whitespace/operators)
 #undef STRONG_INT_VS_NUMERIC_BINARY_OP
 #undef NUMERIC_VS_STRONG_INT_BINARY_OP
 
@@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>);  // NOLINT(whitespace/operators)
                  StrongInt<TagType, ValueType, ValidatorType> rhs) {  \
     return lhs.value() op rhs.value();                                \
   }
-STRONG_INT_COMPARISON_OP(==);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(!=);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(<);   // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(<=);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(>);   // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(>=);  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(==)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(!=)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(<)   // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(<=)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(>)   // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(>=)  // NOLINT(whitespace/operators)
 #undef STRONG_INT_COMPARISON_OP
 
 }  // namespace intops

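The dropped semicolons are not cosmetic: each of these macros expands to a complete function definition ending in `}`, so a trailing `;` at the invocation site is an extra empty declaration, which pedantic builds (e.g. with `-Wextra-semi`) flag. A reduced illustration with a toy macro rather than the real StrongInt ones:

    // Toy macro that, like the operator macros above, expands to a full
    // function definition.
    #define DEFINE_DOUBLE_FN(name) \
      int name(int x) { return 2 * x; }

    DEFINE_DOUBLE_FN(DoubleIt)     // OK: expansion already ends with '}'
    // DEFINE_DOUBLE_FN(DoubleIt);  // stray ';' after '}' -> -Wextra-semi warning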
@@ -44,7 +44,6 @@ class GraphServiceBase {
 
   constexpr GraphServiceBase(const char* key) : key(key) {}
 
-  virtual ~GraphServiceBase() = default;
   inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
     return DefaultInitializationUnsupported();
   }
@@ -52,14 +51,32 @@ class GraphServiceBase {
   const char* key;
 
  protected:
+  // `GraphService<T>` objects, deriving `GraphServiceBase` are designed to be
+  // global constants and not ever deleted through `GraphServiceBase`. Hence,
+  // protected and non-virtual destructor which helps to make `GraphService<T>`
+  // trivially destructible and properly defined as global constants.
+  //
+  // A class with any virtual functions should have a destructor that is either
+  // public and virtual or else protected and non-virtual.
+  // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual
+  ~GraphServiceBase() = default;
+
   absl::Status DefaultInitializationUnsupported() const {
     return absl::UnimplementedError(absl::StrCat(
         "Graph service '", key, "' does not support default initialization"));
   }
 };
 
+// A global constant to refer a service:
+// - Requesting `CalculatorContract::UseService` from calculator
+// - Accessing `Calculator/SubgraphContext::Service`from calculator/subgraph
+// - Setting before graph initialization `CalculatorGraph::SetServiceObject`
+//
+// NOTE: In headers, define your graph service reference safely as following:
+// `inline constexpr GraphService<YourService> kYourService("YourService");`
+//
 template <typename T>
-class GraphService : public GraphServiceBase {
+class GraphService final : public GraphServiceBase {
  public:
   using type = T;
   using packet_type = std::shared_ptr<T>;
@@ -68,7 +85,7 @@ class GraphService : public GraphServiceBase {
                        kDisallowDefaultInitialization)
       : GraphServiceBase(my_key), default_init_(default_init) {}
 
-  absl::StatusOr<Packet> CreateDefaultObject() const override {
+  absl::StatusOr<Packet> CreateDefaultObject() const final {
     if (default_init_ != kAllowDefaultInitialization) {
       return DefaultInitializationUnsupported();
     }

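Following the NOTE added in the hunk above, a header-defined service constant plus the usual set-before-run flow looks roughly like this. A sketch only: the service type, names, and the calling code are hypothetical, while `GraphService` and `CalculatorGraph::SetServiceObject` are the APIs the diff touches:

    #include <memory>
    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/graph_service.h"

    // In a header: safe as a global constant now that GraphService<T> is
    // trivially destructible.
    struct MyServiceData { int value = 0; };
    inline constexpr mediapipe::GraphService<MyServiceData> kMyService("MyService");

    // Before running a graph, attach the service object.
    absl::Status AttachService(mediapipe::CalculatorGraph& graph) {
      return graph.SetServiceObject(kMyService, std::make_shared<MyServiceData>());
    }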
@@ -7,7 +7,7 @@
 
 namespace mediapipe {
 namespace {
-const GraphService<int> kIntService("mediapipe::IntService");
+constexpr GraphService<int> kIntService("mediapipe::IntService");
 }  // namespace
 
 TEST(GraphServiceManager, SetGetServiceObject) {

@@ -14,6 +14,8 @@
 
 #include "mediapipe/framework/graph_service.h"
 
+#include <type_traits>
+
 #include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/canonical_errors.h"
@@ -159,7 +161,7 @@ TEST_F(GraphServiceTest, CreateDefault) {
 
 struct TestServiceData {};
 
-const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
+constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
     "kTestServiceAllowDefaultInitialization",
     GraphServiceBase::kAllowDefaultInitialization);
 
@@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest,
                        HasSubstr("Service is unavailable.")));
 }
 
-const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization(
-    "kTestServiceDisallowDefaultInitialization",
-    GraphServiceBase::kDisallowDefaultInitialization);
+constexpr GraphService<TestServiceData>
+    kTestServiceDisallowDefaultInitialization(
+        "kTestServiceDisallowDefaultInitialization",
+        GraphServiceBase::kDisallowDefaultInitialization);
+
+static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>,
+              "GraphService is not trivially destructible");
 
 class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator
     : public CalculatorBase {

@@ -16,15 +16,6 @@
 
 namespace mediapipe {
 
-const GraphService<TestServiceObject> kTestService(
-    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
-const GraphService<int> kAnotherService(
-    "another_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NoDefaultConstructor> kNoDefaultService(
-    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NeedsCreateMethod> kNeedsCreateService(
-    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
-
 absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
   cc->Inputs().Index(0).Set<int>();
   cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));

@@ -22,14 +22,17 @@ namespace mediapipe {
 
 using TestServiceObject = std::map<std::string, int>;
 
-extern const GraphService<TestServiceObject> kTestService;
-extern const GraphService<int> kAnotherService;
+inline constexpr GraphService<TestServiceObject> kTestService(
+    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
+inline constexpr GraphService<int> kAnotherService(
+    "another_service", GraphServiceBase::kAllowDefaultInitialization);
 
 class NoDefaultConstructor {
  public:
   NoDefaultConstructor() = delete;
 };
-extern const GraphService<NoDefaultConstructor> kNoDefaultService;
+inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService(
+    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
 
 class NeedsCreateMethod {
  public:
@@ -40,7 +43,8 @@ class NeedsCreateMethod {
  private:
   NeedsCreateMethod() = default;
 };
-extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
+inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService(
+    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
 
 // Use a service.
 class TestServiceCalculator : public CalculatorBase {

@@ -57,7 +57,7 @@ namespace mediapipe {
 // have underflow/overflow etc. This type is used internally by Timestamp
 // and TimestampDiff.
 MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
-                               mediapipe::intops::LogFatalOnError);
+                               mediapipe::intops::LogFatalOnError)
 
 class TimestampDiff;
 

@@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
 #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
   SET_MEDIAPIPE_TYPE_MAP_VALUE(                                                \
       mediapipe::PacketTypeIdToMediaPipeTypeData,                              \
-      mediapipe::tool::GetTypeHash<                                            \
-          mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(),     \
+      mediapipe::TypeId::Of<                                                   \
+          mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()      \
+          .hash_code(),                                                        \
       (mediapipe::MediaPipeTypeData{                                           \
-          mediapipe::tool::GetTypeHash<                                        \
-              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+          mediapipe::TypeId::Of<                                               \
+              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()  \
+              .hash_code(),                                                    \
           type_name, serialize_fn, deserialize_fn}));                          \
   SET_MEDIAPIPE_TYPE_MAP_VALUE(                                                \
       mediapipe::PacketTypeStringToMediaPipeTypeData, type_name,               \
       (mediapipe::MediaPipeTypeData{                                           \
-          mediapipe::tool::GetTypeHash<                                        \
-              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+          mediapipe::TypeId::Of<                                               \
+              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()  \
+              .hash_code(),                                                    \
           type_name, serialize_fn, deserialize_fn}));
 // End define MEDIAPIPE_REGISTER_TYPE.
 

@@ -38,7 +38,10 @@ cc_library(
     srcs = ["gpu_service.cc"],
     hdrs = ["gpu_service.h"],
    visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:graph_service"] + select({
+    deps = [
+        "//mediapipe/framework:graph_service",
+        "@com_google_absl//absl/base:core_headers",
+    ] + select({
         "//conditions:default": [
             ":gpu_shared_data_internal",
         ],
@@ -292,6 +295,7 @@ cc_library(
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:logging",
         "@com_google_absl//absl/functional:bind_front",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ] + select({
@@ -630,6 +634,7 @@ cc_library(
         "//mediapipe/framework:executor",
        "//mediapipe/framework/deps:no_destructor",
         "//mediapipe/framework/port:ret_check",
+        "@com_google_absl//absl/base:core_headers",
     ] + select({
         "//conditions:default": [],
         "//mediapipe:apple": [

@@ -47,6 +47,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height,
   auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height,
                                                 format, nullptr);
   if (!buf->CreateInternal(data, alignment)) {
+    LOG(WARNING) << "Failed to create a GL texture";
     return nullptr;
   }
   return buf;
@@ -106,7 +107,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width,
 
 bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
   auto context = GlContext::GetCurrent();
-  if (!context) return false;
+  if (!context) {
+    LOG(WARNING) << "Cannot create a GL texture without a valid context";
+    return false;
+  }
 
   producer_context_ = context;  // Save creation GL context.
 

@@ -20,6 +20,7 @@
 #include <memory>
 #include <utility>
 
+#include "absl/log/check.h"
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/gpu/gpu_buffer_format.h"
@@ -72,8 +73,10 @@ class GpuBuffer {
   // are not portable. Applications and calculators should normally obtain
   // GpuBuffers in a portable way from the framework, e.g. using
   // GpuBufferMultiPool.
-  explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
-      : holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
+  explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
+    CHECK(storage) << "Cannot construct GpuBuffer with null storage";
+    holder_ = std::make_shared<StorageHolder>(std::move(storage));
+  }
 
 #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   // This is used to support backward-compatible construction of GpuBuffer from

@@ -28,6 +28,12 @@ namespace mediapipe {
 #define GL_HALF_FLOAT 0x140B
 #endif  // GL_HALF_FLOAT
 
+#ifdef __EMSCRIPTEN__
+#ifndef GL_HALF_FLOAT_OES
+#define GL_HALF_FLOAT_OES 0x8D61
+#endif  // GL_HALF_FLOAT_OES
+#endif  // __EMSCRIPTEN__
+
 #if !MEDIAPIPE_DISABLE_GPU
 #ifdef GL_ES_VERSION_2_0
 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
@@ -48,6 +54,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
     case GL_RG8:
       info->gl_internal_format = info->gl_format = GL_RG_EXT;
       return;
+#ifdef __EMSCRIPTEN__
+    case GL_RGBA16F:
+      info->gl_internal_format = GL_RGBA;
+      info->gl_type = GL_HALF_FLOAT_OES;
+      return;
+#endif  // __EMSCRIPTEN__
     default:
       return;
   }

@@ -15,6 +15,7 @@
 #ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_
 #define MEDIAPIPE_GPU_GPU_SERVICE_H_
 
+#include "absl/base/attributes.h"
 #include "mediapipe/framework/graph_service.h"
 
 #if !MEDIAPIPE_DISABLE_GPU
@@ -29,7 +30,7 @@ class GpuResources {
 };
 #endif  // MEDIAPIPE_DISABLE_GPU
 
-extern const GraphService<GpuResources> kGpuService;
+ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
 
 }  // namespace mediapipe
 

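`ABSL_CONST_INIT` on the extern declaration makes the compiler verify that the definition is constant-initialized, so the service global is safe to use before any dynamic initializers run. A sketch of the declaration/definition pairing with a hypothetical service type (the real `kGpuService` definition is not shown in this diff):

    #include "absl/base/attributes.h"
    #include "mediapipe/framework/graph_service.h"

    struct MyGpuLikeResources {};

    // Header: promise constant initialization.
    ABSL_CONST_INIT extern const mediapipe::GraphService<MyGpuLikeResources>
        kMyResourcesService;

    // Source file: the constexpr GraphService constructor satisfies the promise.
    const mediapipe::GraphService<MyGpuLikeResources> kMyResourcesService(
        "MyResourcesService");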
@@ -14,6 +14,7 @@
 
 #include "mediapipe/gpu/gpu_shared_data_internal.h"
 
+#include "absl/base/attributes.h"
 #include "mediapipe/framework/deps/no_destructor.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/gpu/gl_context.h"
@@ -116,7 +117,7 @@ GpuResources::~GpuResources() {
 #endif  // __APPLE__
 }
 
-extern const GraphService<GpuResources> kGpuService;
+ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
 
 absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
   CHECK(node->Contract().ServiceRequests().contains(kGpuService.key));

@@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
     """Instantiates perceptual loss.
 
     Args:
-      feature_weight: The weight coeffcients of multiple model extracted
+      feature_weight: The weight coefficients of multiple model extracted
         features used for calculating the perceptual loss.
      loss_weight: The weight coefficients between `style_loss` and
         `content_loss`.

@@ -105,7 +105,7 @@ class FaceStylizer(object):
     self._train_model(train_data=train_data, preprocessor=self._preprocessor)
 
   def _create_model(self):
-    """Creates the componenets of face stylizer."""
+    """Creates the components of face stylizer."""
     self._encoder = model_util.load_keras_model(
         constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
     )
@@ -138,7 +138,7 @@ class FaceStylizer(object):
     """
     train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
 
-    # TODO: Support processing mulitple input style images. The
+    # TODO: Support processing multiple input style images. The
     # input style images are expected to have similar style.
     # style_sample represents a tuple of (style_image, style_label).
     style_sample = next(iter(train_dataset))

@@ -103,8 +103,8 @@ class ModelResourcesCache {
 };
 
 // Global service for mediapipe task model resources cache.
-const mediapipe::GraphService<ModelResourcesCache> kModelResourcesCacheService(
-    "mediapipe::tasks::ModelResourcesCacheService");
+inline constexpr mediapipe::GraphService<ModelResourcesCache>
+    kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService");
 
 }  // namespace core
 }  // namespace tasks

@@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
   static constexpr Output<Image>::Multiple kConfidenceMaskOut{
       "CONFIDENCE_MASK"};
   static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
+  static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
+      "QUALITY_SCORES"};
   MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
-                          kConfidenceMaskOut, kCategoryMaskOut);
+                          kConfidenceMaskOut, kCategoryMaskOut,
+                          kQualityScoresOut);
 
   static absl::Status UpdateContract(CalculatorContract* cc);
 
@@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
 
 absl::Status TensorsToSegmentationCalculator::Process(
     mediapipe::CalculatorContext* cc) {
-  RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
-      << "Expect a vector of single Tensor.";
-  const auto& input_tensor = kTensorsIn(cc).Get()[0];
+  const auto& input_tensors = kTensorsIn(cc).Get();
+  if (input_tensors.size() != 1 && input_tensors.size() != 2) {
+    return absl::InvalidArgumentError(
+        "Expect input tensor vector of size 1 or 2.");
+  }
+  const auto& input_tensor = *input_tensors.rbegin();
   ASSIGN_OR_RETURN(const Shape input_shape,
                    GetImageLikeTensorShape(input_tensor));
 
+  // TODO: should use tensor signature to get the correct output
+  // tensor.
+  if (input_tensors.size() == 2) {
+    const auto& quality_tensor = input_tensors[0];
+    const float* quality_score_buffer =
+        quality_tensor.GetCpuReadView().buffer<float>();
+    const std::vector<float> quality_scores(
+        quality_score_buffer,
+        quality_score_buffer +
+            (quality_tensor.bytes() / quality_tensor.element_size()));
+    kQualityScoresOut(cc).Send(quality_scores);
+  } else {
+    // If the input_tensors don't contain quality scores, send the default
+    // quality scores as 1.
+    const std::vector<float> quality_scores(input_shape.channels, 1.0f);
+    kQualityScoresOut(cc).Send(quality_scores);
+  }
+
   // Category mask does not require activation function.
   if (options_.segmenter_options().output_type() ==
           SegmenterOptions::CONFIDENCE_MASK &&

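The selection logic above relies on a convention: when two tensors arrive, the quality-score tensor comes first and the segmentation tensor is always the last one. A plain-STL sketch of the same convention, purely illustrative and with no MediaPipe types:

    #include <optional>
    #include <vector>

    struct SplitTensors {
      std::optional<std::vector<float>> quality_scores;  // present only with 2 tensors
      std::vector<float> segmentation;                   // always the last tensor
    };

    // Mirrors the calculator's convention: last tensor is the mask tensor, an
    // optional leading tensor carries per-category quality scores.
    SplitTensors Split(const std::vector<std::vector<float>>& tensors) {
      SplitTensors out;
      out.segmentation = tensors.back();
      if (tensors.size() == 2) out.quality_scores = tensors.front();
      return out;
    }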
@@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
 constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {
@@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
           category_mask =
               status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
         }
+        const std::vector<float>& quality_scores =
+            status_or_packets.value()[kQualityScoresStreamName]
+                .Get<std::vector<float>>();
         Packet image_packet = status_or_packets.value()[kImageOutStreamName];
         result_callback(
-            {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
+            {{confidence_masks, category_mask, quality_scores}},
+            image_packet.Get<Image>(),
             image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
       };
 }
@@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
@@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::Status ImageSegmenter::SegmentAsync(

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <type_traits>
@@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 
 // Struct holding the different output streams produced by the image segmenter
@@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
   std::optional<std::vector<Source<Image>>> confidence_masks;
   std::optional<Source<Image>> category_mask;
   // The same as the input image, mainly used for live stream mode.
+  std::optional<Source<std::vector<float>>> quality_scores;
   Source<Image> image;
 };
 
@@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
         "Segmentation tflite models are assumed to have a single subgraph.",
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
-  const auto* primary_subgraph = (*model.subgraphs())[0];
-  if (primary_subgraph->outputs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Segmentation tflite models are assumed to have a single output.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-
   ASSIGN_OR_RETURN(
       *options->mutable_label_items(),
-      GetLabelItemsIfAny(*metadata_extractor,
-                         *metadata_extractor->GetOutputTensorMetadata()->Get(0),
-                         segmenter_option.display_names_locale()));
+      GetLabelItemsIfAny(
+          *metadata_extractor,
+          **metadata_extractor->GetOutputTensorMetadata()->crbegin(),
+          segmenter_option.display_names_locale()));
   return absl::OkStatus();
 }
 
@@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
   const tflite::Model& model = *model_resources.GetTfLiteModel();
   const auto* primary_subgraph = (*model.subgraphs())[0];
   const auto* output_tensor =
-      (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
+      (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
   return output_tensor;
 }
 
+uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  return primary_subgraph->outputs()->size();
+}
+
 // Get the input tensor from the tflite model of given model resources.
 absl::StatusOr<const tflite::Tensor*> GetInputTensor(
     const core::ModelResources& model_resources) {
@@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
         *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    if (output_streams.quality_scores) {
+      *output_streams.quality_scores >>
+          graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
+    }
     output_streams.image >> graph[Output<Image>(kImageTag)];
     return graph.GetConfig();
   }
@@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
             tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
+    auto quality_scores =
+        tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
     return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
                                  /*confidence_masks=*/std::nullopt,
                                  /*category_mask=*/std::nullopt,
+                                 /*quality_scores=*/quality_scores,
                                  /*image=*/image_and_tensors.image};
   } else {
     std::optional<std::vector<Source<Image>>> confidence_masks;
@@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
     if (output_category_mask_) {
       category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
    }
+    auto quality_scores =
+        tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
     return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
                                  /*confidence_masks=*/confidence_masks,
                                  /*category_mask=*/category_mask,
+                                 /*quality_scores=*/quality_scores,
                                  /*image=*/image_and_tensors.image};
   }
 }

@@ -33,6 +33,10 @@ struct ImageSegmenterResult {
   // A category mask of uint8 image in GRAY8 format where each pixel represents
   // the class which the pixel in the original image was predicted to belong to.
   std::optional<Image> category_mask;
+  // The quality scores of the result masks, in the range of [0, 1]. Defaults to
+  // `1` if the model doesn't output quality scores. Each element corresponds to
+  // the score of the category in the model outputs.
+  std::vector<float> quality_scores;
 };
 
 }  // namespace image_segmenter

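With the new field, callers can read per-category mask quality alongside the masks themselves. A usage sketch: it assumes an already-created `segmenter` (std::unique_ptr<ImageSegmenter>) and a mediapipe::Image `image`, neither of which is defined here, and elides error handling:

    absl::StatusOr<ImageSegmenterResult> result = segmenter->Segment(image);
    if (result.ok()) {
      // One score per category in the model output; the vector holds 1.0 for
      // every category when the model produces no quality tensor.
      for (float score : result->quality_scores) {
        (void)score;  // e.g. skip or flag low-confidence masks
      }
    }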
@@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kRoiStreamName[] = "roi_in";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
 
 constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
 constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
 constexpr absl::string_view kImageTag{"IMAGE"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
@@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 }  // namespace interactive_segmenter

@@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
 constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
           graph[Output<Image>(kCategoryMaskTag)];
     }
   }
+  image_segmenter.Out(kQualityScoresTag) >>
+      graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
   image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
 
   return graph.GetConfig();

@@ -81,7 +81,7 @@ strip_api_include_path_prefix(
         "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
         "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
         "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
-        "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectionResult.h",
+        "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
     ],
 )
 
@@ -162,7 +162,7 @@ apple_static_xcframework(
         ":MPPImageClassifierResult.h",
         ":MPPObjectDetector.h",
         ":MPPObjectDetectorOptions.h",
-        ":MPPObjectDetectionResult.h",
+        ":MPPObjectDetectorResult.h",
     ],
     deps = [
         "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",

@@ -16,17 +16,6 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-/**
- * MediaPipe Tasks delegate.
- */
-typedef NS_ENUM(NSUInteger, MPPDelegate) {
-  /** CPU. */
-  MPPDelegateCPU,
-
-  /** GPU. */
-  MPPDelegateGPU
-} NS_SWIFT_NAME(Delegate);
-
 /**
  * Holds the base options that is used for creation of any type of task. It has fields with
  * important information acceleration configuration, TFLite model source etc.
@@ -37,12 +26,6 @@ NS_SWIFT_NAME(BaseOptions)
 /** The path to the model asset to open and mmap in memory. */
 @property(nonatomic, copy) NSString *modelAssetPath;
 
-/**
- * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
- * delegate CPU is used.
- */
-@property(nonatomic) MPPDelegate delegate;
-
 @end
 
 NS_ASSUME_NONNULL_END

@@ -28,7 +28,6 @@
   MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
 
   baseOptions.modelAssetPath = self.modelAssetPath;
-  baseOptions.delegate = self.delegate;
 
   return baseOptions;
 }

@@ -33,20 +33,6 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
   if (self.modelAssetPath) {
     baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
   }
-
-  switch (self.delegate) {
-    case MPPDelegateCPU: {
-      baseOptionsProto->mutable_acceleration()->mutable_tflite();
-      break;
-    }
-    case MPPDelegateGPU: {
-      // TODO: Provide an implementation for GPU Delegate.
-      [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."];
-      break;
-    }
-    default:
-      break;
-  }
 }
 
 @end

@@ -28,9 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategoryArrays(categories, expectedCategories) \
   XCTAssertEqual(categories.count, expectedCategories.count);     \

@@ -29,9 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
   XCTAssertNotNil(textEmbedderResult);                              \

@@ -34,9 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   XCTAssertNotNil(error);                                      \
   XCTAssertEqualObjects(error.domain, expectedError.domain);   \
   XCTAssertEqual(error.code, expectedError.code);              \
-  XCTAssertNotEqual(                                           \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategoryArrays(categories, expectedCategories) \
   XCTAssertEqual(categories.count, expectedCategories.count);     \
@@ -670,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
   // Because of flow limiting, we cannot ensure that the callback will be
   // invoked `iterationCount` times.
-  // An normal expectation will fail if expectation.fullfill() is not called
+  // An normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times.
   // If `expectation.isInverted = true`, the test will only succeed if
-  // expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since in our case we cannot predict how many times the expectation is
   // supposed to be fullfilled setting,
   // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and

@ -32,9 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
XCTAssertNotNil(error); \
|
XCTAssertNotNil(error); \
|
||||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||||
XCTAssertEqual(error.code, expectedError.code); \
|
XCTAssertEqual(error.code, expectedError.code); \
|
||||||
XCTAssertNotEqual( \
|
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
|
||||||
NSNotFound)
|
|
||||||
|
|
||||||
#define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \
|
#define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \
|
||||||
XCTAssertEqual(category.index, expectedCategory.index, \
|
XCTAssertEqual(category.index, expectedCategory.index, \
|
||||||
|
@ -70,7 +68,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
|
|
||||||
#pragma mark Results
|
#pragma mark Results
|
||||||
|
|
||||||
+ (MPPObjectDetectionResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
+ (MPPObjectDetectorResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||||
(NSInteger)timestampInMilliseconds {
|
(NSInteger)timestampInMilliseconds {
|
||||||
NSArray<MPPDetection *> *detections = @[
|
NSArray<MPPDetection *> *detections = @[
|
||||||
[[MPPDetection alloc] initWithCategories:@[
|
[[MPPDetection alloc] initWithCategories:@[
|
||||||
|
@ -95,8 +93,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
keypoints:nil],
|
keypoints:nil],
|
||||||
];
|
];
|
||||||
|
|
||||||
return [[MPPObjectDetectionResult alloc] initWithDetections:detections
|
return [[MPPObjectDetectorResult alloc] initWithDetections:detections
|
||||||
timestampInMilliseconds:timestampInMilliseconds];
|
timestampInMilliseconds:timestampInMilliseconds];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertDetections:(NSArray<MPPDetection *> *)detections
|
- (void)assertDetections:(NSArray<MPPDetection *> *)detections
|
||||||
|
@ -112,25 +110,25 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult
|
- (void)assertObjectDetectorResult:(MPPObjectDetectorResult *)objectDetectorResult
|
||||||
isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult
|
isEqualToExpectedResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult
|
||||||
expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
|
expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
|
||||||
XCTAssertNotNil(objectDetectionResult);
|
XCTAssertNotNil(objectDetectorResult);
|
||||||
|
|
||||||
NSArray<MPPDetection *> *detectionsSubsetToCompare;
|
NSArray<MPPDetection *> *detectionsSubsetToCompare;
|
||||||
XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount);
|
XCTAssertEqual(objectDetectorResult.detections.count, expectedDetectionsCount);
|
||||||
if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) {
|
if (objectDetectorResult.detections.count > expectedObjectDetectorResult.detections.count) {
|
||||||
detectionsSubsetToCompare = [objectDetectionResult.detections
|
detectionsSubsetToCompare = [objectDetectorResult.detections
|
||||||
subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)];
|
subarrayWithRange:NSMakeRange(0, expectedObjectDetectorResult.detections.count)];
|
||||||
} else {
|
} else {
|
||||||
detectionsSubsetToCompare = objectDetectionResult.detections;
|
detectionsSubsetToCompare = objectDetectorResult.detections;
|
||||||
}
|
}
|
||||||
|
|
||||||
[self assertDetections:detectionsSubsetToCompare
|
[self assertDetections:detectionsSubsetToCompare
|
||||||
isEqualToExpectedDetections:expectedObjectDetectionResult.detections];
|
isEqualToExpectedDetections:expectedObjectDetectorResult.detections];
|
||||||
|
|
||||||
XCTAssertEqual(objectDetectionResult.timestampInMilliseconds,
|
XCTAssertEqual(objectDetectorResult.timestampInMilliseconds,
|
||||||
expectedObjectDetectionResult.timestampInMilliseconds);
|
expectedObjectDetectorResult.timestampInMilliseconds);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark File
|
#pragma mark File
|
||||||
|
@ -195,28 +193,27 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
- (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
|
- (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
|
||||||
usingObjectDetector:(MPPObjectDetector *)objectDetector
|
usingObjectDetector:(MPPObjectDetector *)objectDetector
|
||||||
maxResults:(NSInteger)maxResults
|
maxResults:(NSInteger)maxResults
|
||||||
equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult {
|
equalsObjectDetectorResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult {
|
||||||
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage
|
MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInImage:mppImage error:nil];
|
||||||
error:nil];
|
|
||||||
|
|
||||||
[self assertObjectDetectionResult:objectDetectionResult
|
[self assertObjectDetectorResult:ObjectDetectorResult
|
||||||
isEqualToExpectedResult:expectedObjectDetectionResult
|
isEqualToExpectedResult:expectedObjectDetectorResult
|
||||||
expectedDetectionsCount:maxResults > 0 ? maxResults
|
expectedDetectionsCount:maxResults > 0 ? maxResults
|
||||||
: objectDetectionResult.detections.count];
|
: ObjectDetectorResult.detections.count];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
|
- (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
|
||||||
usingObjectDetector:(MPPObjectDetector *)objectDetector
|
usingObjectDetector:(MPPObjectDetector *)objectDetector
|
||||||
maxResults:(NSInteger)maxResults
|
maxResults:(NSInteger)maxResults
|
||||||
|
|
||||||
equalsObjectDetectionResult:
|
equalsObjectDetectorResult:
|
||||||
(MPPObjectDetectionResult *)expectedObjectDetectionResult {
|
(MPPObjectDetectorResult *)expectedObjectDetectorResult {
|
||||||
MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
|
MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
|
||||||
|
|
||||||
[self assertResultsOfDetectInImage:mppImage
|
[self assertResultsOfDetectInImage:mppImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:maxResults
|
maxResults:maxResults
|
||||||
equalsObjectDetectionResult:expectedObjectDetectionResult];
|
equalsObjectDetectorResult:expectedObjectDetectorResult];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark General Tests
|
#pragma mark General Tests
|
||||||
|
@ -266,10 +263,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:-1
|
maxResults:-1
|
||||||
equalsObjectDetectionResult:
|
equalsObjectDetectorResult:
|
||||||
[MPPObjectDetectorTests
|
[MPPObjectDetectorTests
|
||||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||||
0]];
|
0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithOptionsSucceeds {
|
- (void)testDetectWithOptionsSucceeds {
|
||||||
|
@ -280,10 +277,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:-1
|
maxResults:-1
|
||||||
equalsObjectDetectionResult:
|
equalsObjectDetectorResult:
|
||||||
[MPPObjectDetectorTests
|
[MPPObjectDetectorTests
|
||||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||||
0]];
|
0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithMaxResultsSucceeds {
|
- (void)testDetectWithMaxResultsSucceeds {
|
||||||
|
@ -297,10 +294,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:maxResults
|
maxResults:maxResults
|
||||||
equalsObjectDetectionResult:
|
equalsObjectDetectorResult:
|
||||||
[MPPObjectDetectorTests
|
[MPPObjectDetectorTests
|
||||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
|
||||||
0]];
|
0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithScoreThresholdSucceeds {
|
- (void)testDetectWithScoreThresholdSucceeds {
|
||||||
|
@ -316,13 +313,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
boundingBox:CGRectMake(608, 161, 381, 439)
|
boundingBox:CGRectMake(608, 161, 381, 439)
|
||||||
keypoints:nil],
|
keypoints:nil],
|
||||||
];
|
];
|
||||||
MPPObjectDetectionResult *expectedObjectDetectionResult =
|
MPPObjectDetectorResult *expectedObjectDetectorResult =
|
||||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||||
|
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:-1
|
maxResults:-1
|
||||||
equalsObjectDetectionResult:expectedObjectDetectionResult];
|
equalsObjectDetectorResult:expectedObjectDetectorResult];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithCategoryAllowlistSucceeds {
|
- (void)testDetectWithCategoryAllowlistSucceeds {
|
||||||
|
@ -359,13 +356,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
keypoints:nil],
|
keypoints:nil],
|
||||||
];
|
];
|
||||||
|
|
||||||
MPPObjectDetectionResult *expectedDetectionResult =
|
MPPObjectDetectorResult *expectedDetectionResult =
|
||||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||||
|
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:-1
|
maxResults:-1
|
||||||
equalsObjectDetectionResult:expectedDetectionResult];
|
equalsObjectDetectorResult:expectedDetectionResult];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithCategoryDenylistSucceeds {
|
- (void)testDetectWithCategoryDenylistSucceeds {
|
||||||
|
@ -414,13 +411,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
keypoints:nil],
|
keypoints:nil],
|
||||||
];
|
];
|
||||||
|
|
||||||
MPPObjectDetectionResult *expectedDetectionResult =
|
MPPObjectDetectorResult *expectedDetectionResult =
|
||||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||||
|
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:-1
|
maxResults:-1
|
||||||
equalsObjectDetectionResult:expectedDetectionResult];
|
equalsObjectDetectorResult:expectedDetectionResult];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithOrientationSucceeds {
|
- (void)testDetectWithOrientationSucceeds {
|
||||||
|
@ -437,8 +434,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
keypoints:nil],
|
keypoints:nil],
|
||||||
];
|
];
|
||||||
|
|
||||||
MPPObjectDetectionResult *expectedDetectionResult =
|
MPPObjectDetectorResult *expectedDetectionResult =
|
||||||
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
|
||||||
|
|
||||||
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage
|
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage
|
||||||
orientation:UIImageOrientationRight];
|
orientation:UIImageOrientationRight];
|
||||||
|
@ -446,7 +443,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
[self assertResultsOfDetectInImage:image
|
[self assertResultsOfDetectInImage:image
|
||||||
usingObjectDetector:objectDetector
|
usingObjectDetector:objectDetector
|
||||||
maxResults:1
|
maxResults:1
|
||||||
equalsObjectDetectionResult:expectedDetectionResult];
|
equalsObjectDetectorResult:expectedDetectionResult];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark Running Mode Tests
|
#pragma mark Running Mode Tests
|
||||||
|
@ -613,15 +610,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage];
|
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage];
|
||||||
|
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image
|
MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInVideoFrame:image
|
||||||
timestampInMilliseconds:i
|
timestampInMilliseconds:i
|
||||||
error:nil];
|
error:nil];
|
||||||
|
|
||||||
[self assertObjectDetectionResult:objectDetectionResult
|
[self assertObjectDetectorResult:ObjectDetectorResult
|
||||||
isEqualToExpectedResult:
|
isEqualToExpectedResult:
|
||||||
[MPPObjectDetectorTests
|
[MPPObjectDetectorTests
|
||||||
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
|
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
|
||||||
expectedDetectionsCount:maxResults];
|
expectedDetectionsCount:maxResults];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -676,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
|
|
||||||
// Because of flow limiting, we cannot ensure that the callback will be
|
// Because of flow limiting, we cannot ensure that the callback will be
|
||||||
// invoked `iterationCount` times.
|
// invoked `iterationCount` times.
|
||||||
// An normal expectation will fail if expectation.fullfill() is not called
|
// An normal expectation will fail if expectation.fulfill() is not called
|
||||||
// `expectation.expectedFulfillmentCount` times.
|
// `expectation.expectedFulfillmentCount` times.
|
||||||
// If `expectation.isInverted = true`, the test will only succeed if
|
// If `expectation.isInverted = true`, the test will only succeed if
|
||||||
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
// Since in our case we cannot predict how many times the expectation is
|
// Since in our case we cannot predict how many times the expectation is
|
||||||
// supposed to be fullfilled setting,
|
// supposed to be fulfilled setting,
|
||||||
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
// `expectation.isInverted = true` ensures that test succeeds if
|
// `expectation.isInverted = true` ensures that test succeeds if
|
||||||
// expectation is fullfilled <= `iterationCount` times.
|
// expectation is fulfilled <= `iterationCount` times.
|
||||||
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||||
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
||||||
expectation.expectedFulfillmentCount = iterationCount + 1;
|
expectation.expectedFulfillmentCount = iterationCount + 1;
|
||||||
|
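The comment in the hunk above describes a standard XCTest idiom for bounding, rather than pinning, a callback count. A minimal Swift sketch of that inverted-expectation pattern (the test name, iteration count and frame-submission step are placeholder assumptions, not part of this change):

    import XCTest

    final class FlowLimitedCallbackTests: XCTestCase {
      func testCallbackFiresAtMostIterationCountTimes() {
        let iterationCount = 100

        // Inverted expectation: the test fails only if fulfill() is called
        // expectedFulfillmentCount (= iterationCount + 1) times, i.e. if the
        // callback fires more often than the number of submitted frames.
        let expectation = XCTestExpectation(description: "liveStreamCallbacks")
        expectation.expectedFulfillmentCount = iterationCount + 1
        expectation.isInverted = true

        for _ in 0..<iterationCount {
          // Submit one frame to the flow-limited pipeline here (omitted) and
          // call expectation.fulfill() from its result callback.
        }

        // Succeeds as long as the expectation was fulfilled at most
        // iterationCount times before the timeout elapses.
        wait(for: [expectation], timeout: 5.0)
      }
    }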
@@ -714,16 +711,16 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";

 #pragma mark MPPObjectDetectorLiveStreamDelegate Methods
 - (void)objectDetector:(MPPObjectDetector *)objectDetector
-    didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult
+    didFinishDetectionWithResult:(MPPObjectDetectorResult *)ObjectDetectorResult
          timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                            error:(NSError *)error {
   NSInteger maxResults = 4;
-  [self assertObjectDetectionResult:objectDetectionResult
+  [self assertObjectDetectorResult:ObjectDetectorResult
            isEqualToExpectedResult:
                [MPPObjectDetectorTests
                    expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
                        timestampInMilliseconds]
            expectedDetectionsCount:maxResults];

   if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) {
     [outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];

@@ -64,4 +64,3 @@ objc_library(
         "//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
     ],
 )
-

@@ -87,7 +87,7 @@ NS_SWIFT_NAME(GestureRecognizerOptions)
     gestureRecognizerLiveStreamDelegate;

 /** Sets the maximum number of hands can be detected by the GestureRecognizer. */
-@property(nonatomic) NSInteger numberOfHands;
+@property(nonatomic) NSInteger numberOfHands NS_SWIFT_NAME(numHands);

 /** Sets minimum confidence score for the hand detection to be considered successful */
 @property(nonatomic) float minHandDetectionConfidence;
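With NS_SWIFT_NAME(numHands), the Objective-C property above is imported into Swift under the shorter name. A minimal Swift sketch of how callers see it (the module name and option values are assumptions for illustration):

    import MediaPipeTasksVision

    let options = GestureRecognizerOptions()
    // `numberOfHands` in Objective-C surfaces as `numHands` in Swift
    // because of the NS_SWIFT_NAME annotation above.
    options.numHands = 2
    options.minHandDetectionConfidence = 0.5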
@@ -31,7 +31,8 @@
   MPPGestureRecognizerOptions *gestureRecognizerOptions = [super copyWithZone:zone];

   gestureRecognizerOptions.runningMode = self.runningMode;
-  gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate = self.gestureRecognizerLiveStreamDelegate;
+  gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate =
+      self.gestureRecognizerLiveStreamDelegate;
   gestureRecognizerOptions.numberOfHands = self.numberOfHands;
   gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence;
   gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence;

@@ -18,9 +18,9 @@
 - (instancetype)initWithGestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
                       handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
                        landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
                   worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
   self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
   if (self) {
     _landmarks = landmarks;

@@ -22,7 +22,12 @@ objc_library(
     hdrs = ["sources/MPPGestureRecognizerOptions+Helpers.h"],
     deps = [
         "//mediapipe/framework:calculator_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
         "//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers",
         "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",

@@ -18,7 +18,12 @@
 #import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
 #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"

+#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"

 namespace {
 using CalculatorOptionsProto = mediapipe::CalculatorOptions;

@@ -17,7 +17,7 @@

 #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

-static const int kMicroSecondsPerMilliSecond = 1000;
+static const int kMicrosecondsPerMillisecond = 1000;

 namespace {
 using ClassificationResultProto =

@@ -29,19 +29,26 @@ using ::mediapipe::Packet;

 + (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
     (const Packet &)packet {
-  MPPClassificationResult *classificationResult;
+  // Even if packet does not validate as the expected type, you can safely access the timestamp.
+  NSInteger timestampInMilliSeconds =
+      (NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
+
   if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
-    return nil;
+    // MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s
+    // timestamp_ms(). It is 0 since the packet can't be validated as a `ClassificationResultProto`.
+    return [[MPPImageClassifierResult alloc]
+        initWithClassificationResult:[[MPPClassificationResult alloc] initWithClassifications:@[]
+                                                              timestampInMilliseconds:0]
+             timestampInMilliseconds:timestampInMilliSeconds];
   }

-  classificationResult = [MPPClassificationResult
+  MPPClassificationResult *classificationResult = [MPPClassificationResult
       classificationResultWithProto:packet.Get<ClassificationResultProto>()];

   return [[MPPImageClassifierResult alloc]
       initWithClassificationResult:classificationResult
            timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
-                                                kMicroSecondsPerMilliSecond)];
+                                                kMicrosecondsPerMillisecond)];
 }

 @end

@@ -17,9 +17,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 licenses(["notice"])

 objc_library(
-    name = "MPPObjectDetectionResult",
-    srcs = ["sources/MPPObjectDetectionResult.m"],
-    hdrs = ["sources/MPPObjectDetectionResult.h"],
+    name = "MPPObjectDetectorResult",
+    srcs = ["sources/MPPObjectDetectorResult.m"],
+    hdrs = ["sources/MPPObjectDetectorResult.h"],
     deps = [
         "//mediapipe/tasks/ios/components/containers:MPPDetection",
         "//mediapipe/tasks/ios/core:MPPTaskResult",

@@ -31,7 +31,7 @@ objc_library(
     srcs = ["sources/MPPObjectDetectorOptions.m"],
     hdrs = ["sources/MPPObjectDetectorOptions.h"],
     deps = [
-        ":MPPObjectDetectionResult",
+        ":MPPObjectDetectorResult",
        "//mediapipe/tasks/ios/core:MPPTaskOptions",
        "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
     ],

@@ -47,8 +47,8 @@ objc_library(
         "-x objective-c++",
     ],
     deps = [
-        ":MPPObjectDetectionResult",
         ":MPPObjectDetectorOptions",
+        ":MPPObjectDetectorResult",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
         "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",

@@ -56,7 +56,7 @@ objc_library(
         "//mediapipe/tasks/ios/vision/core:MPPImage",
         "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
         "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
-        "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectionResultHelpers",
         "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers",
+        "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorResultHelpers",
     ],
 )

@@ -15,8 +15,8 @@
 #import <Foundation/Foundation.h>

 #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
 #import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"

 NS_ASSUME_NONNULL_BEGIN

@@ -109,14 +109,13 @@ NS_SWIFT_NAME(ObjectDetector)
 * @param error An optional error parameter populated when there is an error in performing object
 * detection on the input image.
 *
- * @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
+ * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
 * has a bounding box that is expressed in the unrotated input frame of reference coordinates
 * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
 * image data.
 */
-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
-                                                error:(NSError **)error
-    NS_SWIFT_NAME(detect(image:));
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
+                                               error:(NSError **)error NS_SWIFT_NAME(detect(image:));

 /**
 * Performs object detection on the provided video frame of type `MPPImage` using the whole

@@ -139,14 +138,14 @@ NS_SWIFT_NAME(ObjectDetector)
 * @param error An optional error parameter populated when there is an error in performing object
 * detection on the input image.
 *
- * @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
+ * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
 * has a bounding box that is expressed in the unrotated input frame of reference coordinates
 * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
 * image data.
 */
-- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
+- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
                    timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                                      error:(NSError **)error
     NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));

 /**
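For reference, the NS_SWIFT_NAME annotations in the two hunks above surface these calls in Swift roughly as sketched below. This is a hedged sketch only: the module name, model file, image construction and option values are assumptions and not part of this change.

    import MediaPipeTasksVision
    import UIKit

    let options = ObjectDetectorOptions()
    options.baseOptions.modelAssetPath = "efficientdet_lite0.tflite"  // hypothetical model path
    options.runningMode = .image
    options.maxResults = 3

    let detector = try ObjectDetector(options: options)
    let input = try MPImage(uiImage: UIImage(named: "cats_and_dogs.jpg")!)

    // Image mode, i.e. NS_SWIFT_NAME(detect(image:)).
    let result = try detector.detect(image: input)

    // With runningMode = .video the equivalent call would be
    // try detector.detect(videoFrame: input, timestampInMilliseconds: 0).
    for detection in result.detections {
      // Bounding boxes are in unrotated input-image pixel coordinates.
      print(detection.categories.first?.categoryName ?? "?", detection.boundingBox)
    }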
@@ -19,8 +19,8 @@
 #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
-#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
 #import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h"
+#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"

 namespace {
 using ::mediapipe::NormalizedRect;

@@ -118,9 +118,9 @@ static NSString *const kTaskName = @"objectDetector";
       return;
     }

-    MPPObjectDetectionResult *result = [MPPObjectDetectionResult
-        objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName
-                                                                              .cppString]];
+    MPPObjectDetectorResult *result = [MPPObjectDetectorResult
+        objectDetectorResultWithDetectionsPacket:statusOrPackets
+                                                     .value()[kDetectionsStreamName.cppString]];

     NSInteger timeStampInMilliseconds =
         outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /

@@ -184,9 +184,9 @@ static NSString *const kTaskName = @"objectDetector";
   return inputPacketMap;
 }

-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
                                     regionOfInterest:(CGRect)roi
                                                error:(NSError **)error {
   std::optional<NormalizedRect> rect =
       [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
                                                    imageSize:CGSizeMake(image.width, image.height)

@@ -213,18 +213,18 @@ static NSString *const kTaskName = @"objectDetector";
     return nil;
   }

-  return [MPPObjectDetectionResult
-      objectDetectionResultWithDetectionsPacket:outputPacketMap
+  return [MPPObjectDetectorResult
+      objectDetectorResultWithDetectionsPacket:outputPacketMap
                                                    .value()[kDetectionsStreamName.cppString]];
 }

-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
   return [self detectInImage:image regionOfInterest:CGRectZero error:error];
 }

-- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
+- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
                                   timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                                                     error:(NSError **)error {
   std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
                                                      timestampInMilliseconds:timestampInMilliseconds
                                                                        error:error];

@@ -239,9 +239,9 @@ static NSString *const kTaskName = @"objectDetector";
     return nil;
   }

-  return [MPPObjectDetectionResult
-      objectDetectionResultWithDetectionsPacket:outputPacketMap
+  return [MPPObjectDetectorResult
+      objectDetectorResultWithDetectionsPacket:outputPacketMap
                                                    .value()[kDetectionsStreamName.cppString]];
 }

 - (BOOL)detectAsyncInImage:(MPPImage *)image

@@ -16,7 +16,7 @@

 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"

 NS_ASSUME_NONNULL_BEGIN

@@ -44,7 +44,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
 *
 * @param objectDetector The object detector which performed the object detection.
 * This is useful to test equality when there are multiple instances of `MPPObjectDetector`.
- * @param result The `MPPObjectDetectionResult` object that contains a list of detections, each
+ * @param result The `MPPObjectDetectorResult` object that contains a list of detections, each
 * detection has a bounding box that is expressed in the unrotated input frame of reference
 * coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
 * underlying image data.

@@ -54,7 +54,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
 * detection on the input live stream image data.
 */
 - (void)objectDetector:(MPPObjectDetector *)objectDetector
-    didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result
+    didFinishDetectionWithResult:(nullable MPPObjectDetectorResult *)result
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                           error:(nullable NSError *)error
    NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:));
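In Swift, the delegate requirement above (through its NS_SWIFT_NAME) is adopted roughly as sketched below; the class name and what is done with the result are placeholder assumptions:

    import MediaPipeTasksVision

    final class DetectionStreamHandler: NSObject, ObjectDetectorLiveStreamDelegate {
      func objectDetector(_ objectDetector: ObjectDetector,
                          didFinishDetection result: ObjectDetectorResult?,
                          timestampInMilliseconds: Int,
                          error: Error?) {
        // Invoked once per processed live-stream frame.
        guard let result = result, error == nil else { return }
        print("detections at \(timestampInMilliseconds) ms: \(result.detections.count)")
      }
    }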
@@ -19,8 +19,8 @@
 NS_ASSUME_NONNULL_BEGIN

 /** Represents the detection results generated by `MPPObjectDetector`. */
-NS_SWIFT_NAME(ObjectDetectionResult)
-@interface MPPObjectDetectionResult : MPPTaskResult
+NS_SWIFT_NAME(ObjectDetectorResult)
+@interface MPPObjectDetectorResult : MPPTaskResult

 /**
 * The array of `MPPDetection` objects each of which has a bounding box that is expressed in the

@@ -30,7 +30,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
 @property(nonatomic, readonly) NSArray<MPPDetection *> *detections;

 /**
- * Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in
+ * Initializes a new `MPPObjectDetectorResult` with the given array of detections and timestamp (in
 * milliseconds).
 *
 * @param detections An array of `MPPDetection` objects each of which has a bounding box that is

@@ -38,7 +38,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
 * x [0,image_height)`, which are the dimensions of the underlying image data.
 * @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
 *
- * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
+ * @return An instance of `MPPObjectDetectorResult` initialized with the given array of detections
 * and timestamp (in milliseconds).
 */
 - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
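Because of NS_SWIFT_NAME(ObjectDetectorResult), the renamed type and its initializer are visible to Swift tests roughly like this (a minimal sketch; the imported initializer spelling follows the usual Objective-C-to-Swift rules and is an assumption here):

    import XCTest
    import MediaPipeTasksVision

    // An expected result with no detections at timestamp 0; a real test would
    // fill the detections array with hand-built values to compare against.
    let expected = ObjectDetectorResult(detections: [], timestampInMilliseconds: 0)
    XCTAssertEqual(expected.timestampInMilliseconds, 0)
    XCTAssertTrue(expected.detections.isEmpty)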
@ -12,9 +12,9 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
|
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
|
||||||
|
|
||||||
@implementation MPPObjectDetectionResult
|
@implementation MPPObjectDetectorResult
|
||||||
|
|
||||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
||||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
|
@ -31,12 +31,12 @@ objc_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "MPPObjectDetectionResultHelpers",
|
name = "MPPObjectDetectorResultHelpers",
|
||||||
srcs = ["sources/MPPObjectDetectionResult+Helpers.mm"],
|
srcs = ["sources/MPPObjectDetectorResult+Helpers.mm"],
|
||||||
hdrs = ["sources/MPPObjectDetectionResult+Helpers.h"],
|
hdrs = ["sources/MPPObjectDetectorResult+Helpers.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
|
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
|
||||||
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectionResult",
|
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectorResult",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
|
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
|
||||||
|
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
|
|
||||||
|
@ -20,17 +20,17 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
static const int kMicroSecondsPerMilliSecond = 1000;
|
static const int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
@interface MPPObjectDetectionResult (Helpers)
|
@interface MPPObjectDetectorResult (Helpers)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an `MPPObjectDetectionResult` from a MediaPipe packet containing a
|
* Creates an `MPPObjectDetectorResult` from a MediaPipe packet containing a
|
||||||
* `std::vector<DetectionProto>`.
|
* `std::vector<DetectionProto>`.
|
||||||
*
|
*
|
||||||
* @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`.
|
* @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`.
|
||||||
*
|
*
|
||||||
* @return An `MPPObjectDetectionResult` object that contains a list of detections.
|
* @return An `MPPObjectDetectorResult` object that contains a list of detections.
|
||||||
*/
|
*/
|
||||||
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
|
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
|
||||||
(const mediapipe::Packet &)packet;
|
(const mediapipe::Packet &)packet;
|
||||||
|
|
||||||
@end
|
@end
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
|
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
|
||||||
|
|
||||||
|
@ -21,9 +21,9 @@ using DetectionProto = ::mediapipe::Detection;
|
||||||
using ::mediapipe::Packet;
|
using ::mediapipe::Packet;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@implementation MPPObjectDetectionResult (Helpers)
|
@implementation MPPObjectDetectorResult (Helpers)
|
||||||
|
|
||||||
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
|
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
|
||||||
(const Packet &)packet {
|
(const Packet &)packet {
|
||||||
if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
|
if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
|
||||||
return nil;
|
return nil;
|
||||||
|
@ -37,10 +37,10 @@ using ::mediapipe::Packet;
|
||||||
[detections addObject:[MPPDetection detectionWithProto:detectionProto]];
|
[detections addObject:[MPPDetection detectionWithProto:detectionProto]];
|
||||||
}
|
}
|
||||||
|
|
||||||
return [[MPPObjectDetectionResult alloc]
|
return
|
||||||
initWithDetections:detections
|
[[MPPObjectDetectorResult alloc] initWithDetections:detections
|
||||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond)];
|
kMicroSecondsPerMilliSecond)];
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
|
@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
// For 90° and 270° rotations, we need to swap width and height.
|
// For 90° and 270° rotations, we need to swap width and height.
|
||||||
// This is due to the internal behavior of ImageToTensorCalculator, which:
|
// This is due to the internal behavior of ImageToTensorCalculator, which:
|
||||||
// - first denormalizes the provided rect by multiplying the rect width or
|
// - first denormalizes the provided rect by multiplying the rect width or
|
||||||
// height by the image width or height, repectively.
|
// height by the image width or height, respectively.
|
||||||
// - then rotates this by denormalized rect by the provided rotation, and
|
// - then rotates this by denormalized rect by the provided rotation, and
|
||||||
// uses this for cropping,
|
// uses this for cropping,
|
||||||
// - then finally rotates this back.
|
// - then finally rotates this back.
|
||||||
|
|
|
@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
segmenterOptions.outputCategoryMask()
|
segmenterOptions.outputCategoryMask()
|
||||||
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
||||||
: -1;
|
: -1;
|
||||||
|
final int qualityScoresOutStreamIndex =
|
||||||
|
getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
|
||||||
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
|
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
|
||||||
|
|
||||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||||
|
@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
|
new ArrayList<>(),
|
||||||
packets.get(imageOutStreamIndex).getTimestamp());
|
packets.get(imageOutStreamIndex).getTimestamp());
|
||||||
}
|
}
|
||||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||||
|
@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||||
categoryMask = Optional.of(builder.build());
|
categoryMask = Optional.of(builder.build());
|
||||||
}
|
}
|
||||||
|
float[] qualityScores =
|
||||||
|
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
|
||||||
|
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
|
||||||
|
for (float score : qualityScores) {
|
||||||
|
qualityScoresList.add(score);
|
||||||
|
}
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
confidenceMasks,
|
confidenceMasks,
|
||||||
categoryMask,
|
categoryMask,
|
||||||
|
qualityScoresList,
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
|
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
|
||||||
}
|
}
|
||||||
|
@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
public abstract Builder setOutputCategoryMask(boolean value);
|
public abstract Builder setOutputCategoryMask(boolean value);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
* /** Sets an optional {@link ResultListener} to receive the segmentation results when the
|
||||||
* pipeline is done processing an image.
|
* graph pipeline is done processing an image.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setResultListener(
|
public abstract Builder setResultListener(
|
||||||
ResultListener<ImageSegmenterResult, MPImage> value);
|
ResultListener<ImageSegmenterResult, MPImage> value);
|
||||||
|
|
|
@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
||||||
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||||
* category mask, where each pixel represents the class which the pixel in the original image
|
* category mask, where each pixel represents the class which the pixel in the original image
|
||||||
* was predicted to belong to.
|
* was predicted to belong to.
|
||||||
|
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
|
||||||
|
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||||
|
* the category in the model outputs.
|
||||||
* @param timestampMs a timestamp for this result.
|
* @param timestampMs a timestamp for this result.
|
||||||
*/
|
*/
|
||||||
// TODO: consolidate output formats across platforms.
|
// TODO: consolidate output formats across platforms.
|
||||||
public static ImageSegmenterResult create(
|
public static ImageSegmenterResult create(
|
||||||
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
|
Optional<List<MPImage>> confidenceMasks,
|
||||||
|
Optional<MPImage> categoryMask,
|
||||||
|
List<Float> qualityScores,
|
||||||
|
long timestampMs) {
|
||||||
return new AutoValue_ImageSegmenterResult(
|
return new AutoValue_ImageSegmenterResult(
|
||||||
confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
|
confidenceMasks.map(Collections::unmodifiableList),
|
||||||
|
categoryMask,
|
||||||
|
Collections.unmodifiableList(qualityScores),
|
||||||
|
timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Optional<List<MPImage>> confidenceMasks();
|
public abstract Optional<List<MPImage>> confidenceMasks();
|
||||||
|
|
||||||
public abstract Optional<MPImage> categoryMask();
|
public abstract Optional<MPImage> categoryMask();
|
||||||
|
|
||||||
|
public abstract List<Float> qualityScores();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public abstract long timestampMs();
|
public abstract long timestampMs();
|
||||||
}
|
}
|
||||||
|
|
|
@@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 outputStreams.add("CATEGORY_MASK:category_mask");
 }
 final int categoryMaskOutStreamIndex = outputStreams.size() - 1;

+outputStreams.add("QUALITY_SCORES:quality_scores");
+final int qualityScoresOutStreamIndex = outputStreams.size() - 1;

 outputStreams.add("IMAGE:image_out");
 // TODO: add test for stream indices.
 final int imageOutStreamIndex = outputStreams.size() - 1;
@@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 return ImageSegmenterResult.create(
 Optional.empty(),
 Optional.empty(),
+new ArrayList<>(),
 packets.get(imageOutStreamIndex).getTimestamp());
 }
 // If resultListener is not provided, the resulted MPImage is deep copied from
@@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 categoryMask = Optional.of(builder.build());
 }

+float[] qualityScores =
+PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
+List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
+for (float score : qualityScores) {
+qualityScoresList.add(score);
+}
+
 return ImageSegmenterResult.create(
 confidenceMasks,
 categoryMask,
+qualityScoresList,
 BaseVisionTaskApi.generateResultTimestampMs(
 RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
 }
@@ -201,6 +201,7 @@ py_test(
 "//mediapipe/tasks/testdata/vision:test_images",
 "//mediapipe/tasks/testdata/vision:test_models",
 ],
+tags = ["not_run:arm"],
 deps = [
 "//mediapipe/python:_framework_bindings",
 "//mediapipe/tasks/python/components/containers:rect",
@@ -27,13 +27,14 @@ export declare interface Detection {
 boundingBox?: BoundingBox;

 /**
- * Optional list of keypoints associated with the detection. Keypoints
- * represent interesting points related to the detection. For example, the
- * keypoints represent the eye, ear and mouth from face detection model. Or
- * in the template matching detection, e.g. KNIFT, they can represent the
- * feature points for template matching.
+ * List of keypoints associated with the detection. Keypoints represent
+ * interesting points related to the detection. For example, the keypoints
+ * represent the eye, ear and mouth from face detection model. Or in the
+ * template matching detection, e.g. KNIFT, they can represent the feature
+ * points for template matching. Contains an empty list if no keypoints are
+ * detected.
 */
-keypoints?: NormalizedKeypoint[];
+keypoints: NormalizedKeypoint[];
 }

 /** Detection results of a model. */
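With `keypoints` now a required (possibly empty) array, consumers no longer need an existence check before iterating. A minimal sketch, assuming a `Detection` value named `detection` obtained from any of the detector results (the variable name is illustrative):

// Sketch only: `detection` is any Detection produced by a MediaPipe detector.
for (const keypoint of detection.keypoints) {  // always an array, may be empty
  console.log(`keypoint at (${keypoint.x}, ${keypoint.y})`);
}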
@@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
 categoryName: '',
 displayName: '',
 }],
-boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
+boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
+keypoints: []
 });
 });
 });
@@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
 const labels = source.getLabelList();
 const displayNames = source.getDisplayNameList();

-const detection: Detection = {categories: []};
+const detection: Detection = {categories: [], keypoints: []};
 for (let i = 0; i < scores.length; i++) {
 detection.categories.push({
 score: scores[i],
@@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
 }

 if (source.getLocationData()?.getRelativeKeypointsList().length) {
-detection.keypoints = [];
 for (const keypoint of
 source.getLocationData()!.getRelativeKeypointsList()) {
 detection.keypoints.push({
@@ -62,7 +62,10 @@ jasmine_node_test(
 mediapipe_ts_library(
 name = "mask",
 srcs = ["mask.ts"],
-deps = [":image"],
+deps = [
+":image",
+"//mediapipe/web/graph_runner:platform_utils",
+],
 )

 mediapipe_ts_library(
@@ -60,6 +60,10 @@ class MPImageTestContext {

 this.webGLTexture = gl.createTexture()!;
 gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
 gl.texImage2D(
 gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap);
 gl.bindTexture(gl.TEXTURE_2D, null);
@@ -187,10 +187,11 @@ export class MPImage {
 destinationContainer =
 assertNotNull(gl.createTexture(), 'Failed to create texture');
 gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
+this.configureTextureParams();
 gl.texImage2D(
 gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA,
 gl.UNSIGNED_BYTE, null);
+gl.bindTexture(gl.TEXTURE_2D, null);

 shaderContext.bindFramebuffer(gl, destinationContainer);
 shaderContext.run(gl, /* flipVertically= */ false, () => {
@@ -302,6 +303,20 @@ export class MPImage {
 return webGLTexture;
 }

+/** Sets texture params for the currently bound texture. */
+private configureTextureParams() {
+const gl = this.getGL();
+// `gl.LINEAR` might break rendering for some textures, but it allows us to
+// do smooth resizing. Ideally, this would be user-configurable, but for now
+// we hard-code the value here to `gl.LINEAR` (versus `gl.NEAREST` for
+// `MPMask` where we do not want to interpolate mask values, especially for
+// category masks).
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
+}

 /**
 * Binds the backing texture to the canvas. If the texture does not yet
 * exist, creates it first.
@@ -318,16 +333,12 @@ export class MPImage {
 assertNotNull(gl.createTexture(), 'Failed to create texture');
 this.containers.push(webGLTexture);
 this.ownsWebGLTexture = true;

+gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
+this.configureTextureParams();
+} else {
+gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
 }

-gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
-// TODO: Ideally, we would only set these once per texture and
-// not once every frame.
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);

 return webGLTexture;
 }
@@ -60,8 +60,11 @@ class MPMaskTestContext {
 }

 this.webGLTexture = gl.createTexture()!;

 gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
 gl.texImage2D(
 gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT,
 new Float32Array(pixels).map(v => v / 255));
@@ -15,6 +15,7 @@
 */

 import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
+import {isIOS} from '../../../../web/graph_runner/platform_utils';

 /** Number of instances a user can keep alive before we raise a warning. */
 const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@@ -32,6 +33,8 @@ enum MPMaskType {
 /** The supported mask formats. For internal usage. */
 export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;


 /**
 * The wrapper class for MediaPipe segmentation masks.
 *
@@ -56,6 +59,9 @@ export class MPMask {
 */
 private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;

+/** The format used to write pixel values from textures. */
+private static texImage2DFormat?: GLenum;

 /** @hideconstructor */
 constructor(
 private readonly containers: MPMaskContainer[],
@@ -127,6 +133,29 @@ export class MPMask {
 return this.convertToWebGLTexture();
 }

+/**
+ * Returns the texture format used for writing float textures on this
+ * platform.
+ */
+getTexImage2DFormat(): GLenum {
+const gl = this.getGL();
+if (!MPMask.texImage2DFormat) {
+// Note: This is the same check we use in
+// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
+if (gl.getExtension('EXT_color_buffer_float') &&
+gl.getExtension('OES_texture_float_linear') &&
+gl.getExtension('EXT_float_blend')) {
+MPMask.texImage2DFormat = gl.R32F;
+} else if (gl.getExtension('EXT_color_buffer_half_float')) {
+MPMask.texImage2DFormat = gl.R16F;
+} else {
+throw new Error(
+'GPU does not fully support 4-channel float32 or float16 formats');
+}
+}
+return MPMask.texImage2DFormat;
+}

 private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
 private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
 private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
@@ -175,8 +204,10 @@
 destinationContainer =
 assertNotNull(gl.createTexture(), 'Failed to create texture');
 gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
+this.configureTextureParams();
+const format = this.getTexImage2DFormat();
 gl.texImage2D(
-gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
+gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
 gl.FLOAT, null);
 gl.bindTexture(gl.TEXTURE_2D, null);

@@ -207,7 +238,7 @@
 if (!this.canvas) {
 throw new Error(
 'Conversion to different image formats require that a canvas ' +
-'is passed when iniitializing the image.');
+'is passed when initializing the image.');
 }
 if (!this.gl) {
 this.gl = assertNotNull(
@@ -215,11 +246,6 @@
 'You cannot use a canvas that is already bound to a different ' +
 'type of rendering context.');
 }
-const ext = this.gl.getExtension('EXT_color_buffer_float');
-if (!ext) {
-// TODO: Ensure this works on iOS
-throw new Error('Missing required EXT_color_buffer_float extension');
-}
 return this.gl;
 }

@@ -237,18 +263,34 @@
 if (uint8Array) {
 float32Array = new Float32Array(uint8Array).map(v => v / 255);
 } else {
+float32Array = new Float32Array(this.width * this.height);

 const gl = this.getGL();
 const shaderContext = this.getShaderContext();
-float32Array = new Float32Array(this.width * this.height);

 // Create texture if needed
 const webGlTexture = this.convertToWebGLTexture();

 // Create a framebuffer from the texture and read back pixels
 shaderContext.bindFramebuffer(gl, webGlTexture);
-gl.readPixels(
-0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
-shaderContext.unbindFramebuffer();
+if (isIOS()) {
+// WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
+// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
+// Uint16Array, however, and requires a manual bitwise conversion from
+// Uint16 to floating point numbers. This conversion is more expensive
+// that reading back a Float32Array from the RGBA image and dropping
+// the superfluous data, so we do this instead.
+const outputArray = new Float32Array(this.width * this.height * 4);
+gl.readPixels(
+0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
+for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
+float32Array[i] = outputArray[j];
+}
+} else {
+gl.readPixels(
+0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
+}
 }
 this.containers.push(float32Array);
 }
@@ -273,9 +315,9 @@
 webGLTexture = this.bindTexture();

 const data = this.convertToFloat32Array();
-// TODO: Add support for R16F to support iOS
+const format = this.getTexImage2DFormat();
 gl.texImage2D(
-gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
+gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
 gl.FLOAT, data);
 this.unbindTexture();
 }
@@ -283,6 +325,19 @@
 return webGLTexture;
 }

+/** Sets texture params for the currently bound texture. */
+private configureTextureParams() {
+const gl = this.getGL();
+// `gl.NEAREST` ensures that we do not get interpolated values for
+// masks. In some cases, the user might want interpolation (e.g. for
+// confidence masks), so we might want to make this user-configurable.
+// Note that `MPImage` uses `gl.LINEAR`.
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
+gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
+}

 /**
 * Binds the backing texture to the canvas. If the texture does not yet
 * exist, creates it first.
@@ -299,15 +354,12 @@
 assertNotNull(gl.createTexture(), 'Failed to create texture');
 this.containers.push(webGLTexture);
 this.ownsWebGLTexture = true;
-}

 gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
-// TODO: Ideally, we would only set these once per texture and
-// not once every frame.
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
-gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
+this.configureTextureParams();
+} else {
+gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
+}

 return webGLTexture;
 }
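The iOS branch above reads the mask back as RGBA/FLOAT and keeps only the red channel, because single-channel float reads are not supported by WebKit there. A standalone sketch of the same readback technique against a plain WebGL2 context follows; the names `gl`, `width`, `height`, and `iOS` are assumptions for illustration, and a framebuffer wrapping the texture is assumed to be bound already. This is not the MPMask implementation itself.

// Hypothetical helper mirroring the "read RGBA, drop extra channels" approach.
function readSingleChannelFloat(
    gl: WebGL2RenderingContext, width: number, height: number,
    iOS: boolean): Float32Array {
  const out = new Float32Array(width * height);
  if (iOS) {
    // Read all four channels and keep only R.
    const rgba = new Float32Array(width * height * 4);
    gl.readPixels(0, 0, width, height, gl.RGBA, gl.FLOAT, rgba);
    for (let i = 0, j = 0; i < out.length; ++i, j += 4) {
      out[i] = rgba[j];
    }
  } else {
    // Single-channel float read works where the platform supports it.
    gl.readPixels(0, 0, width, height, gl.RED, gl.FLOAT, out);
  }
  return out;
}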
@@ -191,7 +191,8 @@ describe('FaceDetector', () => {
 categoryName: '',
 displayName: '',
 }],
-boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
+boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
+keypoints: []
 });
 });
 });
@@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
 /**
 * Performs face stylization on the provided single image and returns the
 * result. This method creates a copy of the resulting image and should not be
-* used in high-throughput applictions. Only use this method when the
+* used in high-throughput applications. Only use this method when the
 * FaceStylizer is created with the image running mode.
 *
 * @param image An image to process.
@@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
 /**
 * Performs face stylization on the provided single image and returns the
 * result. This method creates a copy of the resulting image and should not be
-* used in high-throughput applictions. Only use this method when the
+* used in high-throughput applications. Only use this method when the
 * FaceStylizer is created with the image running mode.
 *
 * The 'imageProcessingOptions' parameter can be used to specify one or all
@@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
 /**
 * Performs face stylization on the provided video frame. This method creates
 * a copy of the resulting image and should not be used in high-throughput
-* applictions. Only use this method when the FaceStylizer is created with the
+* applications. Only use this method when the FaceStylizer is created with the
 * video running mode.
 *
 * The input frame can be of any size. It's required to provide the video
@@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
 const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
 const CATEGORY_MASK_STREAM = 'category_mask';
+const QUALITY_SCORES_STREAM = 'quality_scores';
 const IMAGE_SEGMENTER_GRAPH =
 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
 export class ImageSegmenter extends VisionTaskRunner {
 private categoryMask?: MPMask;
 private confidenceMasks?: MPMask[];
+private qualityScores?: number[];
 private labels: string[] = [];
 private userCallback?: ImageSegmenterCallback;
 private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
 /**
 * Performs image segmentation on the provided single image and returns the
 * segmentation result. This method creates a copy of the resulting masks and
-* should not be used in high-throughput applictions. Only use this method
+* should not be used in high-throughput applications. Only use this method
 * when the ImageSegmenter is created with running mode `image`.
 *
 * @param image An image to process.
@@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
 /**
 * Performs image segmentation on the provided single image and returns the
 * segmentation result. This method creates a copy of the resulting masks and
-* should not be used in high-v applictions. Only use this method when
+* should not be used in high-v applications. Only use this method when
 * the ImageSegmenter is created with running mode `image`.
 *
 * @param image An image to process.
@@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
 /**
 * Performs image segmentation on the provided video frame and returns the
 * segmentation result. This method creates a copy of the resulting masks and
-* should not be used in high-v applictions. Only use this method when
+* should not be used in high-v applications. Only use this method when
 * the ImageSegmenter is created with running mode `video`.
 *
 * @param videoFrame A video frame to process.
@@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
 private reset(): void {
 this.categoryMask = undefined;
 this.confidenceMasks = undefined;
+this.qualityScores = undefined;
 }

 private processResults(): ImageSegmenterResult|void {
 try {
-const result =
-new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
+const result = new ImageSegmenterResult(
+this.confidenceMasks, this.categoryMask, this.qualityScores);
 if (this.userCallback) {
 this.userCallback(result);
 } else {
@@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
 });
 }

+graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
+segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
+
+this.graphRunner.attachFloatVectorListener(
+QUALITY_SCORES_STREAM, (scores, timestamp) => {
+this.qualityScores = scores;
+this.setLatestOutputTimestamp(timestamp);
+});
+this.graphRunner.attachEmptyPacketListener(
+QUALITY_SCORES_STREAM, timestamp => {
+this.categoryMask = undefined;
+this.setLatestOutputTimestamp(timestamp);
+});
+
 const binaryGraph = graphConfig.serializeBinary();
 this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 }
@@ -30,7 +30,13 @@ export class ImageSegmenterResult {
 * `WebGLTexture`-backed `MPImage` where each pixel represents the class
 * which the pixel in the original image was predicted to belong to.
 */
-readonly categoryMask?: MPMask) {}
+readonly categoryMask?: MPMask,
+/**
+ * The quality scores of the result masks, in the range of [0, 1].
+ * Defaults to `1` if the model doesn't output quality scores. Each
+ * element corresponds to the score of the category in the model outputs.
+ */
+readonly qualityScores?: number[]) {}

 /** Frees the resources held by the category and confidence masks. */
 close(): void {
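On the web API, the new `qualityScores` field surfaces in the result passed to the segment callback. A hedged usage sketch, assuming an already-constructed `ImageSegmenter` instance named `segmenter` and an `HTMLImageElement` named `image` (both names are illustrative):

// Sketch only: `segmenter` and `image` are assumed to exist in scope.
segmenter.segment(image, result => {
  // One score per mask/category; defaults to 1 when the model has no
  // quality output.
  const scores = result.qualityScores ?? [];
  scores.forEach((score, i) => console.log(`mask ${i}: quality ${score}`));
});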
@@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
 ((images: WasmImage, timestamp: number) => void)|undefined;
 confidenceMasksListener:
 ((images: WasmImage[], timestamp: number) => void)|undefined;
+qualityScoresListener:
+((data: number[], timestamp: number) => void)|undefined;

 constructor() {
 super(createSpyWasmModule(), /* glCanvas= */ null);
@@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
 expect(stream).toEqual('confidence_masks');
 this.confidenceMasksListener = listener;
 });
+this.attachListenerSpies[2] =
+spyOn(this.graphRunner, 'attachFloatVectorListener')
+.and.callFake((stream, listener) => {
+expect(stream).toEqual('quality_scores');
+this.qualityScoresListener = listener;
+});
 spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
 this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
 });
@@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
 it('invokes listener after masks are available', async () => {
 const categoryMask = new Uint8Array([1]);
 const confidenceMask = new Float32Array([0.0]);
+const qualityScores = [1.0];
 let listenerCalled = false;

 await imageSegmenter.setOptions(
@@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
 ],
 1337);
 expect(listenerCalled).toBeFalse();
+imageSegmenter.qualityScoresListener!(qualityScores, 1337);
+expect(listenerCalled).toBeFalse();
 });

 return new Promise<void>(resolve => {
-imageSegmenter.segment({} as HTMLImageElement, () => {
+imageSegmenter.segment({} as HTMLImageElement, result => {
 listenerCalled = true;
+expect(result.categoryMask).toBeInstanceOf(MPMask);
+expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
+expect(result.qualityScores).toEqual(qualityScores);
 resolve();
 });
 });
@@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
 const ROI_IN_STREAM = 'roi_in';
 const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
 const CATEGORY_MASK_STREAM = 'category_mask';
+const QUALITY_SCORES_STREAM = 'quality_scores';
 const IMAGEA_SEGMENTER_GRAPH =
 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
 const DEFAULT_OUTPUT_CATEGORY_MASK = false;
@@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
 export class InteractiveSegmenter extends VisionTaskRunner {
 private categoryMask?: MPMask;
 private confidenceMasks?: MPMask[];
+private qualityScores?: number[];
 private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
 private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
 private userCallback?: InteractiveSegmenterCallback;
@@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 private reset(): void {
 this.confidenceMasks = undefined;
 this.categoryMask = undefined;
+this.qualityScores = undefined;
 }

 private processResults(): InteractiveSegmenterResult|void {
 try {
 const result = new InteractiveSegmenterResult(
-this.confidenceMasks, this.categoryMask);
+this.confidenceMasks, this.categoryMask, this.qualityScores);
 if (this.userCallback) {
 this.userCallback(result);
 } else {
@@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 });
 }

+graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
+segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
+
+this.graphRunner.attachFloatVectorListener(
+QUALITY_SCORES_STREAM, (scores, timestamp) => {
+this.qualityScores = scores;
+this.setLatestOutputTimestamp(timestamp);
+});
+this.graphRunner.attachEmptyPacketListener(
+QUALITY_SCORES_STREAM, timestamp => {
+this.categoryMask = undefined;
+this.setLatestOutputTimestamp(timestamp);
+});
+
 const binaryGraph = graphConfig.serializeBinary();
 this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 }
@@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
 * `WebGLTexture`-backed `MPImage` where each pixel represents the class
 * which the pixel in the original image was predicted to belong to.
 */
-readonly categoryMask?: MPMask) {}
+readonly categoryMask?: MPMask,
+/**
+ * The quality scores of the result masks, in the range of [0, 1].
+ * Defaults to `1` if the model doesn't output quality scores. Each
+ * element corresponds to the score of the category in the model outputs.
+ */
+readonly qualityScores?: number[]) {}

 /** Frees the resources held by the category and confidence masks. */
 close(): void {
@@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
 ((images: WasmImage, timestamp: number) => void)|undefined;
 confidenceMasksListener:
 ((images: WasmImage[], timestamp: number) => void)|undefined;
+qualityScoresListener:
+((data: number[], timestamp: number) => void)|undefined;
 lastRoi?: RenderDataProto;

 constructor() {
@@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
 expect(stream).toEqual('confidence_masks');
 this.confidenceMasksListener = listener;
 });
+this.attachListenerSpies[2] =
+spyOn(this.graphRunner, 'attachFloatVectorListener')
+.and.callFake((stream, listener) => {
+expect(stream).toEqual('quality_scores');
+this.qualityScoresListener = listener;
+});
 spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
 this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
 });
@@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
 });
 });

-it('invokes listener after masks are avaiblae', async () => {
+it('invokes listener after masks are available', async () => {
 const categoryMask = new Uint8Array([1]);
 const confidenceMask = new Float32Array([0.0]);
+const qualityScores = [1.0];
 let listenerCalled = false;

 await interactiveSegmenter.setOptions(
@@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
 ],
 1337);
 expect(listenerCalled).toBeFalse();
+interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
+expect(listenerCalled).toBeFalse();
 });

 return new Promise<void>(resolve => {
-interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
+interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
 listenerCalled = true;
+expect(result.categoryMask).toBeInstanceOf(MPMask);
+expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
+expect(result.qualityScores).toEqual(qualityScores);
 resolve();
 });
 });
@@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
 categoryName: '',
 displayName: '',
 }],
-boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
+boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
+keypoints: []
 });
 });
 });
@@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
 // it uses "CriOS".
 return userAgent.includes('Safari') && !userAgent.includes('Chrome');
 }

+/** Detect if code is running on iOS. */
+export function isIOS() {
+// Source:
+// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
+return [
+'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
+'iPod'
+// tslint:disable-next-line:deprecation
+].includes(navigator.platform)
+// iPad on iOS 13 detection
+|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
+}
setup.py
@@ -357,7 +357,10 @@ class BuildExtension(build_ext.build_ext):
 for ext in self.extensions:
 target_name = self.get_ext_fullpath(ext.name)
 # Build x86
-self._build_binary(ext)
+self._build_binary(
+ext,
+['--cpu=darwin', '--ios_multi_cpus=i386,x86_64,armv7,arm64'],
+)
 x86_name = self.get_ext_fullpath(ext.name)
 # Build Arm64
 ext.name = ext.name + '.arm64'
third_party/flatbuffers/BUILD.bazel (vendored)
@@ -42,16 +42,15 @@ filegroup(
 "include/flatbuffers/allocator.h",
 "include/flatbuffers/array.h",
 "include/flatbuffers/base.h",
-"include/flatbuffers/bfbs_generator.h",
 "include/flatbuffers/buffer.h",
 "include/flatbuffers/buffer_ref.h",
 "include/flatbuffers/code_generator.h",
 "include/flatbuffers/code_generators.h",
 "include/flatbuffers/default_allocator.h",
 "include/flatbuffers/detached_buffer.h",
+"include/flatbuffers/file_manager.h",
 "include/flatbuffers/flatbuffer_builder.h",
 "include/flatbuffers/flatbuffers.h",
-"include/flatbuffers/flatc.h",
 "include/flatbuffers/flex_flat_util.h",
 "include/flatbuffers/flexbuffers.h",
 "include/flatbuffers/grpc.h",
third_party/flatbuffers/workspace.bzl (vendored)
@@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 def repo():
 third_party_http_archive(
 name = "flatbuffers",
-strip_prefix = "flatbuffers-23.1.21",
-sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
+strip_prefix = "flatbuffers-23.5.8",
+sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
 urls = [
-"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
-"https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
+"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
+"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
 ],
 build_file = "//third_party/flatbuffers:BUILD.bazel",
 delete = ["build_defs.bzl", "BUILD.bazel"],