Pulled changes from master

commit 164eae8c16

@@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
 typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
 
+typedef ConcatenateVectorCalculator<std::string>
+    ConcatenateStringVectorCalculator;
+MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator);
+
 // Example config:
 // node {
 //   calculator: "ConcatenateTfLiteTensorVectorCalculator"

@@ -30,13 +30,15 @@ namespace mediapipe {
 typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
 
-void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp,
+template <typename T>
+void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp,
                     CalculatorRunner* runner) {
   runner->MutableInputs()->Index(index).packets.push_back(
-      MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
+      MakePacket<std::vector<T>>(input).At(Timestamp(timestamp)));
 }
 
-void AddInputVectors(const std::vector<std::vector<int>>& inputs,
+template <typename T>
+void AddInputVectors(const std::vector<std::vector<T>>& inputs,
                      int64_t timestamp, CalculatorRunner* runner) {
   for (int i = 0; i < inputs.size(); ++i) {
     AddInputVector(i, inputs[i], timestamp, runner);

@@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
   EXPECT_EQ(0, outputs.size());
 }
 
+TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) {
+  CalculatorRunner runner("ConcatenateStringVectorCalculator",
+                          /*options_string=*/"", /*num_inputs=*/3,
+                          /*num_outputs=*/1, /*num_side_packets=*/0);
+
+  std::vector<std::vector<std::string>> inputs = {
+      {"a", "b"}, {"c"}, {"d", "e", "f"}};
+  AddInputVectors(inputs, /*timestamp=*/1, &runner);
+  MP_ASSERT_OK(runner.Run());
+
+  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
+  std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
+  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
+}
+
 typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
     TestConcatenateUniqueIntPtrCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);

@@ -1099,6 +1099,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
     ],
+    alwayslink = True,  # Defines TestServiceCalculator
 )
 
 cc_library(

@@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
   return std::move(SetNoLogging());
 }
 
-StatusBuilder::operator Status() const& {
+StatusBuilder::operator absl::Status() const& {
   return StatusBuilder(*this).JoinMessageToStatus();
 }
 
-StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
+StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
 
 absl::Status StatusBuilder::JoinMessageToStatus() {
   if (!impl_) {

@@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
     return std::move(*this << msg);
   }
 
-  operator Status() const&;
-  operator Status() &&;
+  operator absl::Status() const&;
+  operator absl::Status() &&;
 
   absl::Status JoinMessageToStatus();
 

@@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
     lhs op## = rhs;                                                       \
     return lhs;                                                           \
   }
-STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
-STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
+STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
+STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
 #undef STRONG_INT_VS_STRONG_INT_BINARY_OP
 
 // Define operators that take one StrongInt and one native integer argument.

@@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
     rhs op## = lhs;                                                       \
     return rhs;                                                           \
   }
-STRONG_INT_VS_NUMERIC_BINARY_OP(*);
-NUMERIC_VS_STRONG_INT_BINARY_OP(*);
-STRONG_INT_VS_NUMERIC_BINARY_OP(/);
-STRONG_INT_VS_NUMERIC_BINARY_OP(%);
-STRONG_INT_VS_NUMERIC_BINARY_OP(<<);  // NOLINT(whitespace/operators)
-STRONG_INT_VS_NUMERIC_BINARY_OP(>>);  // NOLINT(whitespace/operators)
+STRONG_INT_VS_NUMERIC_BINARY_OP(*)
+NUMERIC_VS_STRONG_INT_BINARY_OP(*)
+STRONG_INT_VS_NUMERIC_BINARY_OP(/)
+STRONG_INT_VS_NUMERIC_BINARY_OP(%)
+STRONG_INT_VS_NUMERIC_BINARY_OP(<<)  // NOLINT(whitespace/operators)
+STRONG_INT_VS_NUMERIC_BINARY_OP(>>)  // NOLINT(whitespace/operators)
 #undef STRONG_INT_VS_NUMERIC_BINARY_OP
 #undef NUMERIC_VS_STRONG_INT_BINARY_OP
 

@@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>);  // NOLINT(whitespace/operators)
                           StrongInt<TagType, ValueType, ValidatorType> rhs) { \
     return lhs.value() op rhs.value();                                        \
   }
-STRONG_INT_COMPARISON_OP(==);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(!=);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(<);   // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(<=);  // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(>);   // NOLINT(whitespace/operators)
-STRONG_INT_COMPARISON_OP(>=);  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(==)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(!=)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(<)   // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(<=)  // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(>)   // NOLINT(whitespace/operators)
+STRONG_INT_COMPARISON_OP(>=)  // NOLINT(whitespace/operators)
 #undef STRONG_INT_COMPARISON_OP
 
 }  // namespace intops

@@ -44,7 +44,6 @@ class GraphServiceBase {
 
   constexpr GraphServiceBase(const char* key) : key(key) {}
 
-  virtual ~GraphServiceBase() = default;
   inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
     return DefaultInitializationUnsupported();
   }

@@ -52,14 +51,32 @@ class GraphServiceBase {
   const char* key;
 
  protected:
+  // `GraphService<T>` objects, deriving `GraphServiceBase` are designed to be
+  // global constants and not ever deleted through `GraphServiceBase`. Hence,
+  // protected and non-virtual destructor which helps to make `GraphService<T>`
+  // trivially destructible and properly defined as global constants.
+  //
+  // A class with any virtual functions should have a destructor that is either
+  // public and virtual or else protected and non-virtual.
+  // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual
+  ~GraphServiceBase() = default;
+
   absl::Status DefaultInitializationUnsupported() const {
     return absl::UnimplementedError(absl::StrCat(
         "Graph service '", key, "' does not support default initialization"));
   }
 };
 
+// A global constant to refer a service:
+// - Requesting `CalculatorContract::UseService` from calculator
+// - Accessing `Calculator/SubgraphContext::Service` from calculator/subgraph
+// - Setting before graph initialization `CalculatorGraph::SetServiceObject`
+//
+// NOTE: In headers, define your graph service reference safely as following:
+// `inline constexpr GraphService<YourService> kYourService("YourService");`
+//
 template <typename T>
-class GraphService : public GraphServiceBase {
+class GraphService final : public GraphServiceBase {
  public:
   using type = T;
   using packet_type = std::shared_ptr<T>;

@@ -68,7 +85,7 @@ class GraphService : public GraphServiceBase {
                                                  kDisallowDefaultInitialization)
       : GraphServiceBase(my_key), default_init_(default_init) {}
 
-  absl::StatusOr<Packet> CreateDefaultObject() const override {
+  absl::StatusOr<Packet> CreateDefaultObject() const final {
     if (default_init_ != kAllowDefaultInitialization) {
       return DefaultInitializationUnsupported();
     }

@@ -7,7 +7,7 @@
 
 namespace mediapipe {
 namespace {
-const GraphService<int> kIntService("mediapipe::IntService");
+constexpr GraphService<int> kIntService("mediapipe::IntService");
 }  // namespace
 
 TEST(GraphServiceManager, SetGetServiceObject) {

@@ -14,6 +14,8 @@
 
 #include "mediapipe/framework/graph_service.h"
 
+#include <type_traits>
+
 #include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/canonical_errors.h"

@@ -159,7 +161,7 @@ TEST_F(GraphServiceTest, CreateDefault) {
 
 struct TestServiceData {};
 
-const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
+constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
     "kTestServiceAllowDefaultInitialization",
     GraphServiceBase::kAllowDefaultInitialization);
 

@@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest,
                        HasSubstr("Service is unavailable.")));
 }
 
-const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization(
-    "kTestServiceDisallowDefaultInitialization",
-    GraphServiceBase::kDisallowDefaultInitialization);
+constexpr GraphService<TestServiceData>
+    kTestServiceDisallowDefaultInitialization(
+        "kTestServiceDisallowDefaultInitialization",
+        GraphServiceBase::kDisallowDefaultInitialization);
+
+static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>,
+              "GraphService is not trivially destructible");
 
 class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator
     : public CalculatorBase {

@@ -16,15 +16,6 @@
 
 namespace mediapipe {
 
-const GraphService<TestServiceObject> kTestService(
-    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
-const GraphService<int> kAnotherService(
-    "another_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NoDefaultConstructor> kNoDefaultService(
-    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NeedsCreateMethod> kNeedsCreateService(
-    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
-
 absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
   cc->Inputs().Index(0).Set<int>();
   cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));

@@ -22,14 +22,17 @@ namespace mediapipe {
 
 using TestServiceObject = std::map<std::string, int>;
 
-extern const GraphService<TestServiceObject> kTestService;
-extern const GraphService<int> kAnotherService;
+inline constexpr GraphService<TestServiceObject> kTestService(
+    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
+inline constexpr GraphService<int> kAnotherService(
+    "another_service", GraphServiceBase::kAllowDefaultInitialization);
 
 class NoDefaultConstructor {
  public:
  NoDefaultConstructor() = delete;
 };
-extern const GraphService<NoDefaultConstructor> kNoDefaultService;
+inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService(
+    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
 
 class NeedsCreateMethod {
  public:

@@ -40,7 +43,8 @@ class NeedsCreateMethod {
  private:
   NeedsCreateMethod() = default;
 };
-extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
+inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService(
+    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
 
 // Use a service.
 class TestServiceCalculator : public CalculatorBase {

@@ -57,7 +57,7 @@ namespace mediapipe {
 // have underflow/overflow etc.  This type is used internally by Timestamp
 // and TimestampDiff.
 MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
-                               mediapipe::intops::LogFatalOnError);
+                               mediapipe::intops::LogFatalOnError)
 
 class TimestampDiff;
 

@@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
 #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
   SET_MEDIAPIPE_TYPE_MAP_VALUE(                                                \
       mediapipe::PacketTypeIdToMediaPipeTypeData,                              \
-      mediapipe::tool::GetTypeHash<                                            \
-          mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(),     \
+      mediapipe::TypeId::Of<                                                   \
+          mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()      \
+          .hash_code(),                                                        \
       (mediapipe::MediaPipeTypeData{                                           \
-          mediapipe::tool::GetTypeHash<                                        \
-              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+          mediapipe::TypeId::Of<                                               \
+              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()  \
+              .hash_code(),                                                    \
           type_name, serialize_fn, deserialize_fn}));                          \
   SET_MEDIAPIPE_TYPE_MAP_VALUE(                                                \
       mediapipe::PacketTypeStringToMediaPipeTypeData, type_name,               \
       (mediapipe::MediaPipeTypeData{                                           \
-          mediapipe::tool::GetTypeHash<                                        \
-              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+          mediapipe::TypeId::Of<                                               \
+              mediapipe::type_map_internal::ReflectType<void(type*)>::Type>()  \
+              .hash_code(),                                                    \
           type_name, serialize_fn, deserialize_fn}));
 // End define MEDIAPIPE_REGISTER_TYPE.
 

@@ -38,7 +38,10 @@ cc_library(
     srcs = ["gpu_service.cc"],
     hdrs = ["gpu_service.h"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:graph_service"] + select({
+    deps = [
+        "//mediapipe/framework:graph_service",
+        "@com_google_absl//absl/base:core_headers",
+    ] + select({
         "//conditions:default": [
             ":gpu_shared_data_internal",
         ],

@@ -292,6 +295,7 @@ cc_library(
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:logging",
         "@com_google_absl//absl/functional:bind_front",
+        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ] + select({

@@ -630,6 +634,7 @@ cc_library(
         "//mediapipe/framework:executor",
         "//mediapipe/framework/deps:no_destructor",
         "//mediapipe/framework/port:ret_check",
+        "@com_google_absl//absl/base:core_headers",
     ] + select({
         "//conditions:default": [],
         "//mediapipe:apple": [

@@ -47,6 +47,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height,
   auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height,
                                                 format, nullptr);
   if (!buf->CreateInternal(data, alignment)) {
+    LOG(WARNING) << "Failed to create a GL texture";
     return nullptr;
   }
   return buf;

@@ -106,7 +107,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width,
 
 bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
   auto context = GlContext::GetCurrent();
-  if (!context) return false;
+  if (!context) {
+    LOG(WARNING) << "Cannot create a GL texture without a valid context";
+    return false;
+  }
 
   producer_context_ = context;  // Save creation GL context.
 

@@ -20,6 +20,7 @@
 #include <memory>
 #include <utility>
 
+#include "absl/log/check.h"
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/gpu/gpu_buffer_format.h"

@@ -72,8 +73,10 @@ class GpuBuffer {
   // are not portable. Applications and calculators should normally obtain
   // GpuBuffers in a portable way from the framework, e.g. using
   // GpuBufferMultiPool.
-  explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
-      : holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
+  explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
+    CHECK(storage) << "Cannot construct GpuBuffer with null storage";
+    holder_ = std::make_shared<StorageHolder>(std::move(storage));
+  }
 
 #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   // This is used to support backward-compatible construction of GpuBuffer from

@@ -28,6 +28,12 @@ namespace mediapipe {
 #define GL_HALF_FLOAT 0x140B
 #endif  // GL_HALF_FLOAT
 
+#ifdef __EMSCRIPTEN__
+#ifndef GL_HALF_FLOAT_OES
+#define GL_HALF_FLOAT_OES 0x8D61
+#endif  // GL_HALF_FLOAT_OES
+#endif  // __EMSCRIPTEN__
+
 #if !MEDIAPIPE_DISABLE_GPU
 #ifdef GL_ES_VERSION_2_0
 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {

@@ -48,6 +54,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
     case GL_RG8:
       info->gl_internal_format = info->gl_format = GL_RG_EXT;
       return;
+#ifdef __EMSCRIPTEN__
+    case GL_RGBA16F:
+      info->gl_internal_format = GL_RGBA;
+      info->gl_type = GL_HALF_FLOAT_OES;
+      return;
+#endif  // __EMSCRIPTEN__
     default:
       return;
   }

@@ -15,6 +15,7 @@
 #ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_
 #define MEDIAPIPE_GPU_GPU_SERVICE_H_
 
+#include "absl/base/attributes.h"
 #include "mediapipe/framework/graph_service.h"
 
 #if !MEDIAPIPE_DISABLE_GPU

@@ -29,7 +30,7 @@ class GpuResources {
 };
 #endif  // MEDIAPIPE_DISABLE_GPU
 
-extern const GraphService<GpuResources> kGpuService;
+ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
 
 }  // namespace mediapipe
 

@@ -14,6 +14,7 @@
 
 #include "mediapipe/gpu/gpu_shared_data_internal.h"
 
+#include "absl/base/attributes.h"
 #include "mediapipe/framework/deps/no_destructor.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/gpu/gl_context.h"

@@ -116,7 +117,7 @@ GpuResources::~GpuResources() {
 #endif  // __APPLE__
 }
 
-extern const GraphService<GpuResources> kGpuService;
+ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
 
 absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
   CHECK(node->Contract().ServiceRequests().contains(kGpuService.key));

@@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
     """Instantiates perceptual loss.
 
     Args:
-      feature_weight: The weight coeffcients of multiple model extracted
+      feature_weight: The weight coefficients of multiple model extracted
         features used for calculating the perceptual loss.
       loss_weight: The weight coefficients between `style_loss` and
         `content_loss`.

@@ -105,7 +105,7 @@ class FaceStylizer(object):
     self._train_model(train_data=train_data, preprocessor=self._preprocessor)
 
   def _create_model(self):
-    """Creates the componenets of face stylizer."""
+    """Creates the components of face stylizer."""
     self._encoder = model_util.load_keras_model(
         constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
     )

@@ -138,7 +138,7 @@ class FaceStylizer(object):
     """
     train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
 
-    # TODO: Support processing mulitple input style images. The
+    # TODO: Support processing multiple input style images. The
     # input style images are expected to have similar style.
     # style_sample represents a tuple of (style_image, style_label).
     style_sample = next(iter(train_dataset))

@@ -103,8 +103,8 @@ class ModelResourcesCache {
 };
 
 // Global service for mediapipe task model resources cache.
-const mediapipe::GraphService<ModelResourcesCache> kModelResourcesCacheService(
-    "mediapipe::tasks::ModelResourcesCacheService");
+inline constexpr mediapipe::GraphService<ModelResourcesCache>
+    kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService");
 
 }  // namespace core
 }  // namespace tasks

@@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
   static constexpr Output<Image>::Multiple kConfidenceMaskOut{
       "CONFIDENCE_MASK"};
   static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
+  static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
+      "QUALITY_SCORES"};
   MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
-                          kConfidenceMaskOut, kCategoryMaskOut);
+                          kConfidenceMaskOut, kCategoryMaskOut,
+                          kQualityScoresOut);
 
   static absl::Status UpdateContract(CalculatorContract* cc);
 

@@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
 
 absl::Status TensorsToSegmentationCalculator::Process(
     mediapipe::CalculatorContext* cc) {
-  RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
-      << "Expect a vector of single Tensor.";
-  const auto& input_tensor = kTensorsIn(cc).Get()[0];
+  const auto& input_tensors = kTensorsIn(cc).Get();
+  if (input_tensors.size() != 1 && input_tensors.size() != 2) {
+    return absl::InvalidArgumentError(
+        "Expect input tensor vector of size 1 or 2.");
+  }
+  const auto& input_tensor = *input_tensors.rbegin();
   ASSIGN_OR_RETURN(const Shape input_shape,
                    GetImageLikeTensorShape(input_tensor));
 
+  // TODO: should use tensor signature to get the correct output
+  // tensor.
+  if (input_tensors.size() == 2) {
+    const auto& quality_tensor = input_tensors[0];
+    const float* quality_score_buffer =
+        quality_tensor.GetCpuReadView().buffer<float>();
+    const std::vector<float> quality_scores(
+        quality_score_buffer,
+        quality_score_buffer +
+            (quality_tensor.bytes() / quality_tensor.element_size()));
+    kQualityScoresOut(cc).Send(quality_scores);
+  } else {
+    // If the input_tensors don't contain quality scores, send the default
+    // quality scores as 1.
+    const std::vector<float> quality_scores(input_shape.channels, 1.0f);
+    kQualityScoresOut(cc).Send(quality_scores);
+  }
+
   // Category mask does not require activation function.
   if (options_.segmenter_options().output_type() ==
           SegmenterOptions::CONFIDENCE_MASK &&

@@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
 constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;

@@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {

 | 
			
		|||
            category_mask =
 | 
			
		||||
                status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
 | 
			
		||||
          }
 | 
			
		||||
          const std::vector<float>& quality_scores =
 | 
			
		||||
              status_or_packets.value()[kQualityScoresStreamName]
 | 
			
		||||
                  .Get<std::vector<float>>();
 | 
			
		||||
          Packet image_packet = status_or_packets.value()[kImageOutStreamName];
 | 
			
		||||
          result_callback(
 | 
			
		||||
              {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
 | 
			
		||||
              {{confidence_masks, category_mask, quality_scores}},
 | 
			
		||||
              image_packet.Get<Image>(),
 | 
			
		||||
              image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
 | 
			
		||||
        };
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
@@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(

@@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::Status ImageSegmenter::SegmentAsync(

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <type_traits>

@@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 
 // Struct holding the different output streams produced by the image segmenter

@@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
   std::optional<std::vector<Source<Image>>> confidence_masks;
   std::optional<Source<Image>> category_mask;
   // The same as the input image, mainly used for live stream mode.
+  std::optional<Source<std::vector<float>>> quality_scores;
   Source<Image> image;
 };
 

@@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
         "Segmentation tflite models are assumed to have a single subgraph.",
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
-  const auto* primary_subgraph = (*model.subgraphs())[0];
-  if (primary_subgraph->outputs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Segmentation tflite models are assumed to have a single output.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-
   ASSIGN_OR_RETURN(
       *options->mutable_label_items(),
-      GetLabelItemsIfAny(*metadata_extractor,
-                         *metadata_extractor->GetOutputTensorMetadata()->Get(0),
-                         segmenter_option.display_names_locale()));
+      GetLabelItemsIfAny(
+          *metadata_extractor,
+          **metadata_extractor->GetOutputTensorMetadata()->crbegin(),
+          segmenter_option.display_names_locale()));
   return absl::OkStatus();
 }
 

@@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
   const tflite::Model& model = *model_resources.GetTfLiteModel();
   const auto* primary_subgraph = (*model.subgraphs())[0];
   const auto* output_tensor =
-      (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
+      (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
   return output_tensor;
 }
 
+uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  return primary_subgraph->outputs()->size();
+}
+
 // Get the input tensor from the tflite model of given model resources.
 absl::StatusOr<const tflite::Tensor*> GetInputTensor(
     const core::ModelResources& model_resources) {

@@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
         *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    if (output_streams.quality_scores) {
+      *output_streams.quality_scores >>
+          graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
+    }
     output_streams.image >> graph[Output<Image>(kImageTag)];
     return graph.GetConfig();
   }

@@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
               tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
         }
       }
+      auto quality_scores =
+          tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
       return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
                                    /*confidence_masks=*/std::nullopt,
                                    /*category_mask=*/std::nullopt,
+                                   /*quality_scores=*/quality_scores,
                                    /*image=*/image_and_tensors.image};
     } else {
       std::optional<std::vector<Source<Image>>> confidence_masks;

@@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
       if (output_category_mask_) {
         category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
       }
+      auto quality_scores =
+          tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
       return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
                                    /*confidence_masks=*/confidence_masks,
                                    /*category_mask=*/category_mask,
+                                   /*quality_scores=*/quality_scores,
                                    /*image=*/image_and_tensors.image};
     }
   }

@@ -33,6 +33,10 @@ struct ImageSegmenterResult {
   // A category mask of uint8 image in GRAY8 format where each pixel represents
   // the class which the pixel in the original image was predicted to belong to.
   std::optional<Image> category_mask;
+  // The quality scores of the result masks, in the range of [0, 1]. Defaults to
+  // `1` if the model doesn't output quality scores. Each element corresponds to
+  // the score of the category in the model outputs.
+  std::vector<float> quality_scores;
 };
 
 }  // namespace image_segmenter

@@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kRoiStreamName[] = "roi_in";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
 
 constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
 constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
 constexpr absl::string_view kImageTag{"IMAGE"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};

@@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);

@@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 }  // namespace interactive_segmenter

@@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
 constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is

@@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
             graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    image_segmenter.Out(kQualityScoresTag) >>
+        graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
     image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
 
     return graph.GetConfig();

@@ -81,7 +81,7 @@ strip_api_include_path_prefix(
         "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
         "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
         "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
-        "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectionResult.h",
+        "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
     ],
 )
 

@@ -162,7 +162,7 @@ apple_static_xcframework(
         ":MPPImageClassifierResult.h",
         ":MPPObjectDetector.h",
         ":MPPObjectDetectorOptions.h",
-        ":MPPObjectDetectionResult.h",
+        ":MPPObjectDetectorResult.h",
     ],
     deps = [
         "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",

@@ -16,17 +16,6 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-/**
- * MediaPipe Tasks delegate.
- */
-typedef NS_ENUM(NSUInteger, MPPDelegate) {
-  /** CPU. */
-  MPPDelegateCPU,
-
-  /** GPU. */
-  MPPDelegateGPU
-} NS_SWIFT_NAME(Delegate);
-
 /**
  * Holds the base options that is used for creation of any type of task. It has fields with
  * important information acceleration configuration, TFLite model source etc.

@@ -37,12 +26,6 @@ NS_SWIFT_NAME(BaseOptions)
 /** The path to the model asset to open and mmap in memory. */
 @property(nonatomic, copy) NSString *modelAssetPath;
 
-/**
- * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
- * delegate CPU is used.
- */
-@property(nonatomic) MPPDelegate delegate;
-
 @end
 
 NS_ASSUME_NONNULL_END

@@ -28,7 +28,6 @@
   MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
 
   baseOptions.modelAssetPath = self.modelAssetPath;
-  baseOptions.delegate = self.delegate;
 
   return baseOptions;
 }

@@ -33,20 +33,6 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
   if (self.modelAssetPath) {
     baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
   }
-
-  switch (self.delegate) {
-    case MPPDelegateCPU: {
-      baseOptionsProto->mutable_acceleration()->mutable_tflite();
-      break;
-    }
-    case MPPDelegateGPU: {
-      // TODO: Provide an implementation for GPU Delegate.
-      [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."];
-      break;
-    }
-    default:
-      break;
-  }
 }
 
 @end

@@ -28,9 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
   XCTAssertNotNil(error);                                                                     \
   XCTAssertEqualObjects(error.domain, expectedError.domain);                                  \
   XCTAssertEqual(error.code, expectedError.code);                                             \
-  XCTAssertNotEqual(                                                                          \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategoryArrays(categories, expectedCategories)                         \
   XCTAssertEqual(categories.count, expectedCategories.count);                             \

@@ -29,9 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
   XCTAssertNotNil(error);                                                                     \
   XCTAssertEqualObjects(error.domain, expectedError.domain);                                  \
   XCTAssertEqual(error.code, expectedError.code);                                             \
-  XCTAssertNotEqual(                                                                          \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
   XCTAssertNotNil(textEmbedderResult);                              \

@@ -34,9 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   XCTAssertNotNil(error);                                                                     \
   XCTAssertEqualObjects(error.domain, expectedError.domain);                                  \
   XCTAssertEqual(error.code, expectedError.code);                                             \
-  XCTAssertNotEqual(                                                                          \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategoryArrays(categories, expectedCategories)                         \
   XCTAssertEqual(categories.count, expectedCategories.count);                             \

@@ -670,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
   // Because of flow limiting, we cannot ensure that the callback will be
   // invoked `iterationCount` times.
-  // An normal expectation will fail if expectation.fullfill() is not called
+  // An normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times.
   // If `expectation.isInverted = true`, the test will only succeed if
-  // expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since in our case we cannot predict how many times the expectation is
   // supposed to be fullfilled setting,
   // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and

@@ -32,9 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   XCTAssertNotNil(error);                                                                     \
   XCTAssertEqualObjects(error.domain, expectedError.domain);                                  \
   XCTAssertEqual(error.code, expectedError.code);                                             \
-  XCTAssertNotEqual(                                                                          \
-      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
-      NSNotFound)
+  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 
 #define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex)           \
   XCTAssertEqual(category.index, expectedCategory.index,                                           \

@@ -70,7 +68,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
 #pragma mark Results
 
-+ (MPPObjectDetectionResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
++ (MPPObjectDetectorResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
     (NSInteger)timestampInMilliseconds {
   NSArray<MPPDetection *> *detections = @[
     [[MPPDetection alloc] initWithCategories:@[

@@ -95,8 +93,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
                                    keypoints:nil],
   ];
 
-  return [[MPPObjectDetectionResult alloc] initWithDetections:detections
-                                      timestampInMilliseconds:timestampInMilliseconds];
+  return [[MPPObjectDetectorResult alloc] initWithDetections:detections
+                                     timestampInMilliseconds:timestampInMilliseconds];
 }
 
 - (void)assertDetections:(NSArray<MPPDetection *> *)detections

			@ -112,25 +110,25 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult
 | 
			
		||||
            isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult
 | 
			
		||||
            expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
 | 
			
		||||
  XCTAssertNotNil(objectDetectionResult);
 | 
			
		||||
- (void)assertObjectDetectorResult:(MPPObjectDetectorResult *)objectDetectorResult
 | 
			
		||||
           isEqualToExpectedResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult
 | 
			
		||||
           expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
 | 
			
		||||
  XCTAssertNotNil(objectDetectorResult);
 | 
			
		||||
 | 
			
		||||
  NSArray<MPPDetection *> *detectionsSubsetToCompare;
 | 
			
		||||
  XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount);
 | 
			
		||||
  if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) {
 | 
			
		||||
    detectionsSubsetToCompare = [objectDetectionResult.detections
 | 
			
		||||
        subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)];
 | 
			
		||||
  XCTAssertEqual(objectDetectorResult.detections.count, expectedDetectionsCount);
 | 
			
		||||
  if (objectDetectorResult.detections.count > expectedObjectDetectorResult.detections.count) {
 | 
			
		||||
    detectionsSubsetToCompare = [objectDetectorResult.detections
 | 
			
		||||
        subarrayWithRange:NSMakeRange(0, expectedObjectDetectorResult.detections.count)];
 | 
			
		||||
  } else {
 | 
			
		||||
    detectionsSubsetToCompare = objectDetectionResult.detections;
 | 
			
		||||
    detectionsSubsetToCompare = objectDetectorResult.detections;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  [self assertDetections:detectionsSubsetToCompare
 | 
			
		||||
      isEqualToExpectedDetections:expectedObjectDetectionResult.detections];
 | 
			
		||||
      isEqualToExpectedDetections:expectedObjectDetectorResult.detections];
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(objectDetectionResult.timestampInMilliseconds,
 | 
			
		||||
                 expectedObjectDetectionResult.timestampInMilliseconds);
 | 
			
		||||
  XCTAssertEqual(objectDetectorResult.timestampInMilliseconds,
 | 
			
		||||
                 expectedObjectDetectorResult.timestampInMilliseconds);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark File
 | 
			
		||||
| 
						 | 
				
			
			@ -195,28 +193,27 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
- (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
 | 
			
		||||
                 usingObjectDetector:(MPPObjectDetector *)objectDetector
 | 
			
		||||
                          maxResults:(NSInteger)maxResults
 | 
			
		||||
         equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult {
 | 
			
		||||
  MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage
 | 
			
		||||
                                                                            error:nil];
 | 
			
		||||
          equalsObjectDetectorResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult {
 | 
			
		||||
  MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInImage:mppImage error:nil];
 | 
			
		||||
 | 
			
		||||
  [self assertObjectDetectionResult:objectDetectionResult
 | 
			
		||||
            isEqualToExpectedResult:expectedObjectDetectionResult
 | 
			
		||||
            expectedDetectionsCount:maxResults > 0 ? maxResults
 | 
			
		||||
                                                   : objectDetectionResult.detections.count];
 | 
			
		||||
  [self assertObjectDetectorResult:ObjectDetectorResult
 | 
			
		||||
           isEqualToExpectedResult:expectedObjectDetectorResult
 | 
			
		||||
           expectedDetectionsCount:maxResults > 0 ? maxResults
 | 
			
		||||
                                                  : ObjectDetectorResult.detections.count];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
 | 
			
		||||
                             usingObjectDetector:(MPPObjectDetector *)objectDetector
 | 
			
		||||
                                      maxResults:(NSInteger)maxResults
 | 
			
		||||
 | 
			
		||||
                     equalsObjectDetectionResult:
 | 
			
		||||
                         (MPPObjectDetectionResult *)expectedObjectDetectionResult {
 | 
			
		||||
                      equalsObjectDetectorResult:
 | 
			
		||||
                          (MPPObjectDetectorResult *)expectedObjectDetectorResult {
 | 
			
		||||
  MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfDetectInImage:mppImage
 | 
			
		||||
                 usingObjectDetector:objectDetector
 | 
			
		||||
                          maxResults:maxResults
 | 
			
		||||
         equalsObjectDetectionResult:expectedObjectDetectionResult];
 | 
			
		||||
          equalsObjectDetectorResult:expectedObjectDetectorResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark General Tests
 | 
			
		||||
| 
						 | 
				
			
			@ -266,10 +263,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
  [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
 | 
			
		||||
                             usingObjectDetector:objectDetector
 | 
			
		||||
                                      maxResults:-1
 | 
			
		||||
                     equalsObjectDetectionResult:
 | 
			
		||||
                         [MPPObjectDetectorTests
 | 
			
		||||
                             expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                                 0]];
 | 
			
		||||
                      equalsObjectDetectorResult:
 | 
			
		||||
                          [MPPObjectDetectorTests
 | 
			
		||||
                              expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                                  0]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithOptionsSucceeds {
 | 
			
		||||
| 
						 | 
				
			
			@ -280,10 +277,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
  [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
 | 
			
		||||
                             usingObjectDetector:objectDetector
 | 
			
		||||
                                      maxResults:-1
 | 
			
		||||
                     equalsObjectDetectionResult:
 | 
			
		||||
                         [MPPObjectDetectorTests
 | 
			
		||||
                             expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                                 0]];
 | 
			
		||||
                      equalsObjectDetectorResult:
 | 
			
		||||
                          [MPPObjectDetectorTests
 | 
			
		||||
                              expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                                  0]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithMaxResultsSucceeds {
 | 
			
		||||
| 
						 | 
				
			
			@ -297,10 +294,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
  [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
 | 
			
		||||
                             usingObjectDetector:objectDetector
 | 
			
		||||
                                      maxResults:maxResults
 | 
			
		||||
                     equalsObjectDetectionResult:
 | 
			
		||||
                         [MPPObjectDetectorTests
 | 
			
		||||
                             expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                                 0]];
 | 
			
		||||
                      equalsObjectDetectorResult:
 | 
			
		||||
                          [MPPObjectDetectorTests
 | 
			
		||||
                              expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                                  0]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithScoreThresholdSucceeds {
 | 
			
		||||
| 
						 | 
				
			
			@ -316,13 +313,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
                                 boundingBox:CGRectMake(608, 161, 381, 439)
 | 
			
		||||
                                   keypoints:nil],
 | 
			
		||||
  ];
 | 
			
		||||
  MPPObjectDetectionResult *expectedObjectDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
  MPPObjectDetectorResult *expectedObjectDetectorResult =
 | 
			
		||||
      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
 | 
			
		||||
                             usingObjectDetector:objectDetector
 | 
			
		||||
                                      maxResults:-1
 | 
			
		||||
                     equalsObjectDetectionResult:expectedObjectDetectionResult];
 | 
			
		||||
                      equalsObjectDetectorResult:expectedObjectDetectorResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithCategoryAllowlistSucceeds {
 | 
			
		||||
| 
						 | 
				
			
			@ -359,13 +356,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
                                   keypoints:nil],
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  MPPObjectDetectionResult *expectedDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
  MPPObjectDetectorResult *expectedDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
 | 
			
		||||
                             usingObjectDetector:objectDetector
 | 
			
		||||
                                      maxResults:-1
 | 
			
		||||
                     equalsObjectDetectionResult:expectedDetectionResult];
 | 
			
		||||
                      equalsObjectDetectorResult:expectedDetectionResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithCategoryDenylistSucceeds {
 | 
			
		||||
| 
						 | 
				
			
			@ -414,13 +411,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
                                   keypoints:nil],
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  MPPObjectDetectionResult *expectedDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
  MPPObjectDetectorResult *expectedDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
 | 
			
		||||
                             usingObjectDetector:objectDetector
 | 
			
		||||
                                      maxResults:-1
 | 
			
		||||
                     equalsObjectDetectionResult:expectedDetectionResult];
 | 
			
		||||
                      equalsObjectDetectorResult:expectedDetectionResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithOrientationSucceeds {
 | 
			
		||||
| 
						 | 
				
			
			@ -437,8 +434,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
                                   keypoints:nil],
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  MPPObjectDetectionResult *expectedDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
  MPPObjectDetectorResult *expectedDetectionResult =
 | 
			
		||||
      [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage
 | 
			
		||||
                                orientation:UIImageOrientationRight];
 | 
			
		||||
| 
						 | 
				
			
			@ -446,7 +443,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
  [self assertResultsOfDetectInImage:image
 | 
			
		||||
                 usingObjectDetector:objectDetector
 | 
			
		||||
                          maxResults:1
 | 
			
		||||
         equalsObjectDetectionResult:expectedDetectionResult];
 | 
			
		||||
          equalsObjectDetectorResult:expectedDetectionResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Running Mode Tests
 | 
			
		||||
| 
						 | 
				
			
			@ -613,15 +610,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
  MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage];
 | 
			
		||||
 | 
			
		||||
  for (int i = 0; i < 3; i++) {
 | 
			
		||||
    MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image
 | 
			
		||||
                                                                 timestampInMilliseconds:i
 | 
			
		||||
                                                                                   error:nil];
 | 
			
		||||
    MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInVideoFrame:image
 | 
			
		||||
                                                               timestampInMilliseconds:i
 | 
			
		||||
                                                                                 error:nil];
 | 
			
		||||
 | 
			
		||||
    [self assertObjectDetectionResult:objectDetectionResult
 | 
			
		||||
              isEqualToExpectedResult:
 | 
			
		||||
                  [MPPObjectDetectorTests
 | 
			
		||||
                      expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
 | 
			
		||||
              expectedDetectionsCount:maxResults];
 | 
			
		||||
    [self assertObjectDetectorResult:ObjectDetectorResult
 | 
			
		||||
             isEqualToExpectedResult:
 | 
			
		||||
                 [MPPObjectDetectorTests
 | 
			
		||||
                     expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
 | 
			
		||||
             expectedDetectionsCount:maxResults];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -676,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
 | 
			
		||||
  // Because of flow limiting, we cannot ensure that the callback will be
 | 
			
		||||
  // invoked `iterationCount` times.
 | 
			
		||||
  // An normal expectation will fail if expectation.fullfill() is not called
 | 
			
		||||
  // An normal expectation will fail if expectation.fulfill() is not called
 | 
			
		||||
  // `expectation.expectedFulfillmentCount` times.
 | 
			
		||||
  // If `expectation.isInverted = true`, the test will only succeed if
 | 
			
		||||
  // expectation is not fullfilled for the specified `expectedFulfillmentCount`.
 | 
			
		||||
  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
 | 
			
		||||
  // Since in our case we cannot predict how many times the expectation is
 | 
			
		||||
  // supposed to be fullfilled setting,
 | 
			
		||||
  // supposed to be fulfilled setting,
 | 
			
		||||
  // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
 | 
			
		||||
  // `expectation.isInverted = true` ensures that test succeeds if
 | 
			
		||||
  // expectation is fullfilled <= `iterationCount` times.
 | 
			
		||||
  // expectation is fulfilled <= `iterationCount` times.
 | 
			
		||||
  XCTestExpectation *expectation = [[XCTestExpectation alloc]
 | 
			
		||||
      initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
 | 
			
		||||
  expectation.expectedFulfillmentCount = iterationCount + 1;
 | 
			
		||||
| 
						 | 
				
			
			@ -714,16 +711,16 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		|||
 | 
			
		||||
#pragma mark MPPObjectDetectorLiveStreamDelegate Methods
 | 
			
		||||
- (void)objectDetector:(MPPObjectDetector *)objectDetector
 | 
			
		||||
    didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult
 | 
			
		||||
    didFinishDetectionWithResult:(MPPObjectDetectorResult *)ObjectDetectorResult
 | 
			
		||||
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
 | 
			
		||||
                           error:(NSError *)error {
 | 
			
		||||
  NSInteger maxResults = 4;
 | 
			
		||||
  [self assertObjectDetectionResult:objectDetectionResult
 | 
			
		||||
            isEqualToExpectedResult:
 | 
			
		||||
                [MPPObjectDetectorTests
 | 
			
		||||
                    expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                        timestampInMilliseconds]
 | 
			
		||||
            expectedDetectionsCount:maxResults];
 | 
			
		||||
  [self assertObjectDetectorResult:ObjectDetectorResult
 | 
			
		||||
           isEqualToExpectedResult:
 | 
			
		||||
               [MPPObjectDetectorTests
 | 
			
		||||
                   expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
 | 
			
		||||
                       timestampInMilliseconds]
 | 
			
		||||
           expectedDetectionsCount:maxResults];
 | 
			
		||||
 | 
			
		||||
  if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) {
 | 
			
		||||
    [outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
 | 
			
		||||
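The inverted-expectation pattern documented in the test hunks above is subtle enough to spell out once, since XCTest's semantics here are easy to misread. A minimal Swift sketch of the same idea, detached from MediaPipe; the loop body stands in for whatever live-stream callback would call fulfill():

    import XCTest

    final class FlowLimitedCallbackTests: XCTestCase {
      func testCallbackFiresAtMostIterationCountTimes() {
        let iterationCount = 100

        // Inverted expectation: the test FAILS if fulfill() is called
        // `expectedFulfillmentCount` (= iterationCount + 1) times, so it
        // passes whenever the callback fires <= iterationCount times,
        // which is all flow limiting lets us guarantee.
        let expectation = XCTestExpectation(description: "flowLimitedCallbacks")
        expectation.expectedFulfillmentCount = iterationCount + 1
        expectation.isInverted = true

        for _ in 0..<iterationCount {
          // A live-stream delegate callback would fulfill here.
          expectation.fulfill()
        }

        wait(for: [expectation], timeout: 1.0)
      }
    }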
@@ -64,4 +64,3 @@ objc_library(
         "//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
     ],
 )
-

@@ -87,7 +87,7 @@ NS_SWIFT_NAME(GestureRecognizerOptions)
     gestureRecognizerLiveStreamDelegate;

 /** Sets the maximum number of hands can be detected by the GestureRecognizer. */
-@property(nonatomic) NSInteger numberOfHands;
+@property(nonatomic) NSInteger numberOfHands NS_SWIFT_NAME(numHands);

 /** Sets minimum confidence score for the hand detection to be considered successful */
 @property(nonatomic) float minHandDetectionConfidence;
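With the NS_SWIFT_NAME annotation added above, Swift callers configure this property as numHands while Objective-C keeps numberOfHands. A small sketch, assuming the MediaPipeTasksVision module and a placeholder model path (both assumptions, not confirmed by this diff):

    import MediaPipeTasksVision

    let options = GestureRecognizerOptions()
    options.baseOptions.modelAssetPath = "gesture_recognizer.task"  // placeholder path
    options.numHands = 2  // exposed to Objective-C as numberOfHands
    options.minHandDetectionConfidence = 0.5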
@@ -31,7 +31,8 @@
   MPPGestureRecognizerOptions *gestureRecognizerOptions = [super copyWithZone:zone];

   gestureRecognizerOptions.runningMode = self.runningMode;
-  gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate = self.gestureRecognizerLiveStreamDelegate;
+  gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate =
+      self.gestureRecognizerLiveStreamDelegate;
   gestureRecognizerOptions.numberOfHands = self.numberOfHands;
   gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence;
   gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence;

@@ -18,9 +18,9 @@
 - (instancetype)initWithGestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
                       handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
-                      landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
-                   worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
-          timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
+                       landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
+                  worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
+         timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
   self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
   if (self) {
     _landmarks = landmarks;

@@ -22,7 +22,12 @@ objc_library(
     hdrs = ["sources/MPPGestureRecognizerOptions+Helpers.h"],
     deps = [
         "//mediapipe/framework:calculator_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
         "//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers",
         "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",

@@ -18,7 +18,12 @@
 #import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
 #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"

+#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"

 namespace {
 using CalculatorOptionsProto = mediapipe::CalculatorOptions;

@@ -17,7 +17,7 @@
 #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

-static const int kMicroSecondsPerMilliSecond = 1000;
+static const int kMicrosecondsPerMillisecond = 1000;

 namespace {
 using ClassificationResultProto =

@@ -29,19 +29,26 @@ using ::mediapipe::Packet;
 + (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
     (const Packet &)packet {
-  MPPClassificationResult *classificationResult;
+  // Even if the packet does not validate as the expected type, you can safely access the timestamp.
+  NSInteger timestampInMilliseconds =
+      (NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);

   if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
-    return nil;
+    // MPPClassificationResult's timestamp is populated from the `ClassificationResultProto`'s
+    // timestamp_ms(). It is 0 since the packet can't be validated as a `ClassificationResultProto`.
+    return [[MPPImageClassifierResult alloc]
+        initWithClassificationResult:[[MPPClassificationResult alloc] initWithClassifications:@[]
+                                                                      timestampInMilliseconds:0]
+             timestampInMilliseconds:timestampInMilliseconds];
   }

-  classificationResult = [MPPClassificationResult
+  MPPClassificationResult *classificationResult = [MPPClassificationResult
       classificationResultWithProto:packet.Get<ClassificationResultProto>()];

   return [[MPPImageClassifierResult alloc]
       initWithClassificationResult:classificationResult
            timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
-                                                kMicroSecondsPerMilliSecond)];
+                                                kMicrosecondsPerMillisecond)];
 }

 @end

@@ -17,9 +17,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 licenses(["notice"])

 objc_library(
-    name = "MPPObjectDetectionResult",
-    srcs = ["sources/MPPObjectDetectionResult.m"],
-    hdrs = ["sources/MPPObjectDetectionResult.h"],
+    name = "MPPObjectDetectorResult",
+    srcs = ["sources/MPPObjectDetectorResult.m"],
+    hdrs = ["sources/MPPObjectDetectorResult.h"],
     deps = [
         "//mediapipe/tasks/ios/components/containers:MPPDetection",
         "//mediapipe/tasks/ios/core:MPPTaskResult",

@@ -31,7 +31,7 @@ objc_library(
     srcs = ["sources/MPPObjectDetectorOptions.m"],
     hdrs = ["sources/MPPObjectDetectorOptions.h"],
     deps = [
-        ":MPPObjectDetectionResult",
+        ":MPPObjectDetectorResult",
         "//mediapipe/tasks/ios/core:MPPTaskOptions",
         "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
     ],

@@ -47,8 +47,8 @@ objc_library(
         "-x objective-c++",
     ],
     deps = [
-        ":MPPObjectDetectionResult",
         ":MPPObjectDetectorOptions",
+        ":MPPObjectDetectorResult",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
         "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",

@@ -56,7 +56,7 @@ objc_library(
         "//mediapipe/tasks/ios/vision/core:MPPImage",
         "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
         "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
-        "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectionResultHelpers",
         "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers",
+        "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorResultHelpers",
     ],
 )

@@ -15,8 +15,8 @@
 #import <Foundation/Foundation.h>

 #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
 #import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"

 NS_ASSUME_NONNULL_BEGIN

@@ -109,14 +109,13 @@ NS_SWIFT_NAME(ObjectDetector)
  * @param error An optional error parameter populated when there is an error in performing object
  * detection on the input image.
  *
- * @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
+ * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
  * has a bounding box that is expressed in the unrotated input frame of reference coordinates
  * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
  * image data.
  */
-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
-                                               error:(NSError **)error
-    NS_SWIFT_NAME(detect(image:));
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
+                                              error:(NSError **)error NS_SWIFT_NAME(detect(image:));

 /**
  * Performs object detection on the provided video frame of type `MPPImage` using the whole

@@ -139,14 +138,14 @@ NS_SWIFT_NAME(ObjectDetector)
  * @param error An optional error parameter populated when there is an error in performing object
  * detection on the input image.
  *
- * @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
+ * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
  * has a bounding box that is expressed in the unrotated input frame of reference coordinates
  * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
  * image data.
  */
-- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
-                                  timestampInMilliseconds:(NSInteger)timestampInMilliseconds
-                                                    error:(NSError **)error
+- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
+                                 timestampInMilliseconds:(NSInteger)timestampInMilliseconds
+                                                   error:(NSError **)error
     NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));

 /**
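Because the header above pins Swift names with NS_SWIFT_NAME, the renamed API surfaces to Swift as ObjectDetectorResult and detect(image:). A usage sketch under stated assumptions: the MediaPipeTasksVision module name, the throwing options initializer, and the model file name are illustrative, not confirmed by this diff:

    import MediaPipeTasksVision

    func detectObjects(in image: MPImage) throws -> ObjectDetectorResult {
      let options = ObjectDetectorOptions()
      // Hypothetical model path, for illustration only.
      options.baseOptions.modelAssetPath = "object_detector.tflite"
      options.maxResults = 4

      let detector = try ObjectDetector(options: options)
      // detect(image:) is the Swift projection of detectInImage:error:.
      return try detector.detect(image: image)
    }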
@@ -19,8 +19,8 @@
 #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
-#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
 #import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h"
+#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"

 namespace {
 using ::mediapipe::NormalizedRect;

@@ -118,9 +118,9 @@ static NSString *const kTaskName = @"objectDetector";
           return;
         }

-        MPPObjectDetectionResult *result = [MPPObjectDetectionResult
-            objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName
-                                                                                  .cppString]];
+        MPPObjectDetectorResult *result = [MPPObjectDetectorResult
+            objectDetectorResultWithDetectionsPacket:statusOrPackets
+                                                         .value()[kDetectionsStreamName.cppString]];

         NSInteger timeStampInMilliseconds =
             outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /

@@ -184,9 +184,9 @@ static NSString *const kTaskName = @"objectDetector";
   return inputPacketMap;
 }

-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
-                                    regionOfInterest:(CGRect)roi
-                                               error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
+                                   regionOfInterest:(CGRect)roi
+                                              error:(NSError **)error {
   std::optional<NormalizedRect> rect =
       [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
                                                   imageSize:CGSizeMake(image.width, image.height)

@@ -213,18 +213,18 @@ static NSString *const kTaskName = @"objectDetector";
     return nil;
   }

-  return [MPPObjectDetectionResult
-      objectDetectionResultWithDetectionsPacket:outputPacketMap
-                                                    .value()[kDetectionsStreamName.cppString]];
+  return [MPPObjectDetectorResult
+      objectDetectorResultWithDetectionsPacket:outputPacketMap
+                                                   .value()[kDetectionsStreamName.cppString]];
 }

-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
   return [self detectInImage:image regionOfInterest:CGRectZero error:error];
 }

-- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
-                                  timestampInMilliseconds:(NSInteger)timestampInMilliseconds
-                                                    error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
+                                 timestampInMilliseconds:(NSInteger)timestampInMilliseconds
+                                                   error:(NSError **)error {
   std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
                                                      timestampInMilliseconds:timestampInMilliseconds
                                                                        error:error];

@@ -239,9 +239,9 @@ static NSString *const kTaskName = @"objectDetector";
     return nil;
   }

-  return [MPPObjectDetectionResult
-      objectDetectionResultWithDetectionsPacket:outputPacketMap
-                                                    .value()[kDetectionsStreamName.cppString]];
+  return [MPPObjectDetectorResult
+      objectDetectorResultWithDetectionsPacket:outputPacketMap
+                                                   .value()[kDetectionsStreamName.cppString]];
 }

 - (BOOL)detectAsyncInImage:(MPPImage *)image

@@ -16,7 +16,7 @@
 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"

 NS_ASSUME_NONNULL_BEGIN

@@ -44,7 +44,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
  *
  * @param objectDetector The object detector which performed the object detection.
  * This is useful to test equality when there are multiple instances of `MPPObjectDetector`.
- * @param result The `MPPObjectDetectionResult` object that contains a list of detections, each
+ * @param result The `MPPObjectDetectorResult` object that contains a list of detections, each
  * detection has a bounding box that is expressed in the unrotated input frame of reference
  * coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
  * underlying image data.

@@ -54,7 +54,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
  * detection on the input live stream image data.
  */
 - (void)objectDetector:(MPPObjectDetector *)objectDetector
-    didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result
+    didFinishDetectionWithResult:(nullable MPPObjectDetectorResult *)result
          timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                            error:(nullable NSError *)error
     NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:));

@@ -19,8 +19,8 @@
 NS_ASSUME_NONNULL_BEGIN

 /** Represents the detection results generated by `MPPObjectDetector`. */
-NS_SWIFT_NAME(ObjectDetectionResult)
-@interface MPPObjectDetectionResult : MPPTaskResult
+NS_SWIFT_NAME(ObjectDetectorResult)
+@interface MPPObjectDetectorResult : MPPTaskResult

 /**
  * The array of `MPPDetection` objects each of which has a bounding box that is expressed in the

@@ -30,7 +30,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
 @property(nonatomic, readonly) NSArray<MPPDetection *> *detections;

 /**
- * Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in
+ * Initializes a new `MPPObjectDetectorResult` with the given array of detections and timestamp (in
  * milliseconds).
  *
  * @param detections An array of `MPPDetection` objects each of which has a bounding box that is

@@ -38,7 +38,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
  * x [0,image_height)`, which are the dimensions of the underlying image data.
  * @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
  *
- * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
+ * @return An instance of `MPPObjectDetectorResult` initialized with the given array of detections
  * and timestamp (in milliseconds).
  */
 - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections

@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"

-@implementation MPPObjectDetectionResult
+@implementation MPPObjectDetectorResult

 - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
            timestampInMilliseconds:(NSInteger)timestampInMilliseconds {

@@ -31,12 +31,12 @@ objc_library(
 )

 objc_library(
-    name = "MPPObjectDetectionResultHelpers",
-    srcs = ["sources/MPPObjectDetectionResult+Helpers.mm"],
-    hdrs = ["sources/MPPObjectDetectionResult+Helpers.h"],
+    name = "MPPObjectDetectorResultHelpers",
+    srcs = ["sources/MPPObjectDetectorResult+Helpers.mm"],
+    hdrs = ["sources/MPPObjectDetectorResult+Helpers.h"],
     deps = [
         "//mediapipe/framework:packet",
         "//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
-        "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectionResult",
+        "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectorResult",
     ],
 )

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"

 #include "mediapipe/framework/packet.h"

@@ -20,17 +20,17 @@ NS_ASSUME_NONNULL_BEGIN
 static const int kMicroSecondsPerMilliSecond = 1000;

-@interface MPPObjectDetectionResult (Helpers)
+@interface MPPObjectDetectorResult (Helpers)

 /**
- * Creates an `MPPObjectDetectionResult` from a MediaPipe packet containing a
+ * Creates an `MPPObjectDetectorResult` from a MediaPipe packet containing a
  * `std::vector<DetectionProto>`.
  *
  * @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`.
  *
- * @return  An `MPPObjectDetectionResult` object that contains a list of detections.
+ * @return  An `MPPObjectDetectorResult` object that contains a list of detections.
  */
-+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
++ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
     (const mediapipe::Packet &)packet;

 @end

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
+#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"

 #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"

@@ -21,9 +21,9 @@ using DetectionProto = ::mediapipe::Detection;
 using ::mediapipe::Packet;
 }  // namespace

-@implementation MPPObjectDetectionResult (Helpers)
+@implementation MPPObjectDetectorResult (Helpers)

-+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
++ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
     (const Packet &)packet {
   if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
     return nil;

@@ -37,10 +37,10 @@ using ::mediapipe::Packet;
     [detections addObject:[MPPDetection detectionWithProto:detectionProto]];
   }

-  return [[MPPObjectDetectionResult alloc]
-           initWithDetections:detections
-      timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
-                                          kMicroSecondsPerMilliSecond)];
+  return
+      [[MPPObjectDetectorResult alloc] initWithDetections:detections
+                                  timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
+                                                                      kMicroSecondsPerMilliSecond)];
 }

 @end

@@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
     // For 90° and 270° rotations, we need to swap width and height.
     // This is due to the internal behavior of ImageToTensorCalculator, which:
     // - first denormalizes the provided rect by multiplying the rect width or
-    //   height by the image width or height, repectively.
+    //   height by the image width or height, respectively.
     // - then rotates this by denormalized rect by the provided rotation, and
     //   uses this for cropping,
     // - then finally rotates this back.
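The width/height swap described in the comment above reduces to swapping axes on odd quarter-turns. A standalone Swift sketch of the idea (a hypothetical helper, not part of BaseVisionTaskApi):

    /// Dimensions of a region of interest after rotating the image by a
    /// multiple of 90 degrees: for 90 and 270 the axes swap, so width and
    /// height trade places.
    func rotatedDimensions(width: Int, height: Int, rotationDegrees: Int) -> (width: Int, height: Int) {
      precondition(rotationDegrees % 90 == 0, "Only right-angle rotations are supported")
      let quarterTurns = ((rotationDegrees / 90) % 4 + 4) % 4
      return quarterTurns % 2 == 1 ? (height, width) : (width, height)
    }

    // Example: rotatedDimensions(width: 640, height: 480, rotationDegrees: 90)
    // yields (width: 480, height: 640).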
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
        segmenterOptions.outputCategoryMask()
 | 
			
		||||
            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
			
		||||
            : -1;
 | 
			
		||||
    final int qualityScoresOutStreamIndex =
 | 
			
		||||
        getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
 | 
			
		||||
    final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
 | 
			
		||||
 | 
			
		||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
			
		||||
| 
						 | 
				
			
			@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
              return ImageSegmenterResult.create(
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  new ArrayList<>(),
 | 
			
		||||
                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
			
		||||
            }
 | 
			
		||||
            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
			
		||||
| 
						 | 
				
			
			@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
                  new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
 | 
			
		||||
              categoryMask = Optional.of(builder.build());
 | 
			
		||||
            }
 | 
			
		||||
            float[] qualityScores =
 | 
			
		||||
                PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
 | 
			
		||||
            List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
 | 
			
		||||
            for (float score : qualityScores) {
 | 
			
		||||
              qualityScoresList.add(score);
 | 
			
		||||
            }
 | 
			
		||||
            return ImageSegmenterResult.create(
 | 
			
		||||
                confidenceMasks,
 | 
			
		||||
                categoryMask,
 | 
			
		||||
                qualityScoresList,
 | 
			
		||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
			
		||||
                    segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
 | 
			
		||||
          }
 | 
			
		||||
| 
						 | 
				
			
			@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
      public abstract Builder setOutputCategoryMask(boolean value);
 | 
			
		||||
 | 
			
		||||
      /**
 | 
			
		||||
       * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
 | 
			
		||||
       * pipeline is done processing an image.
 | 
			
		||||
       * /** Sets an optional {@link ResultListener} to receive the segmentation results when the
 | 
			
		||||
       * graph pipeline is done processing an image.
 | 
			
		||||
       */
 | 
			
		||||
      public abstract Builder setResultListener(
 | 
			
		||||
          ResultListener<ImageSegmenterResult, MPImage> value);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
 | 
			
		|||
   * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
 | 
			
		||||
   *     category mask, where each pixel represents the class which the pixel in the original image
 | 
			
		||||
   *     was predicted to belong to.
 | 
			
		||||
   * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
 | 
			
		||||
   *     to `1` if the model doesn't output quality scores. Each element corresponds to the score of
 | 
			
		||||
   *     the category in the model outputs.
 | 
			
		||||
   * @param timestampMs a timestamp for this result.
 | 
			
		||||
   */
 | 
			
		||||
  // TODO: consolidate output formats across platforms.
 | 
			
		||||
  public static ImageSegmenterResult create(
 | 
			
		||||
      Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
 | 
			
		||||
      Optional<List<MPImage>> confidenceMasks,
 | 
			
		||||
      Optional<MPImage> categoryMask,
 | 
			
		||||
      List<Float> qualityScores,
 | 
			
		||||
      long timestampMs) {
 | 
			
		||||
    return new AutoValue_ImageSegmenterResult(
 | 
			
		||||
        confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
 | 
			
		||||
        confidenceMasks.map(Collections::unmodifiableList),
 | 
			
		||||
        categoryMask,
 | 
			
		||||
        Collections.unmodifiableList(qualityScores),
 | 
			
		||||
        timestampMs);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public abstract Optional<List<MPImage>> confidenceMasks();
 | 
			
		||||
 | 
			
		||||
  public abstract Optional<MPImage> categoryMask();
 | 
			
		||||
 | 
			
		||||
  public abstract List<Float> qualityScores();
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public abstract long timestampMs();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
      outputStreams.add("CATEGORY_MASK:category_mask");
 | 
			
		||||
    }
 | 
			
		||||
    final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
 | 
			
		||||
 | 
			
		||||
    outputStreams.add("QUALITY_SCORES:quality_scores");
 | 
			
		||||
    final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
 | 
			
		||||
 | 
			
		||||
    outputStreams.add("IMAGE:image_out");
 | 
			
		||||
    // TODO: add test for stream indices.
 | 
			
		||||
    final int imageOutStreamIndex = outputStreams.size() - 1;
 | 
			
		||||
| 
						 | 
				
			
			@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
              return ImageSegmenterResult.create(
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  new ArrayList<>(),
 | 
			
		||||
                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
			
		||||
            }
 | 
			
		||||
            // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
			
		||||
| 
						 | 
				
			
			@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
              categoryMask = Optional.of(builder.build());
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            float[] qualityScores =
 | 
			
		||||
                PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
 | 
			
		||||
            List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
 | 
			
		||||
            for (float score : qualityScores) {
 | 
			
		||||
              qualityScoresList.add(score);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            return ImageSegmenterResult.create(
 | 
			
		||||
                confidenceMasks,
 | 
			
		||||
                categoryMask,
 | 
			
		||||
                qualityScoresList,
 | 
			
		||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
			
		||||
                    RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
 | 
			
		||||
          }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -201,6 +201,7 @@ py_test(
 | 
			
		|||
        "//mediapipe/tasks/testdata/vision:test_images",
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:test_models",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = ["not_run:arm"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/python:_framework_bindings",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:rect",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
@ -27,13 +27,14 @@ export declare interface Detection {
  boundingBox?: BoundingBox;

  /**
   * Optional list of keypoints associated with the detection. Keypoints
   * represent interesting points related to the detection. For example, the
   * keypoints represent the eye, ear and mouth from face detection model. Or
   * in the template matching detection, e.g. KNIFT, they can represent the
   * feature points for template matching.
   * List of keypoints associated with the detection. Keypoints represent
   * interesting points related to the detection. For example, the keypoints
   * represent the eye, ear and mouth from face detection model. Or in the
   * template matching detection, e.g. KNIFT, they can represent the feature
   * points for template matching. Contains an empty list if no keypoints are
   * detected.
   */
  keypoints?: NormalizedKeypoint[];
  keypoints: NormalizedKeypoint[];
}

/** Detection results of a model. */
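
Note: the hunk above changes `keypoints` from optional to a required field that is simply empty when nothing is detected. A hedged illustration (not part of the commit; `detection` stands in for any `Detection` value) of what this buys consumers:

// Before this change a guard such as `detection.keypoints ?? []` was needed;
// now the field is always an array, so iteration is unconditional.
for (const kp of detection.keypoints) {
  console.log(`keypoint at (${kp.x}, ${kp.y})`);
}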
@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
        categoryName: '',
        displayName: '',
      }],
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
      keypoints: []
    });
  });
});
@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
  const labels = source.getLabelList();
  const displayNames = source.getDisplayNameList();

  const detection: Detection = {categories: []};
  const detection: Detection = {categories: [], keypoints: []};
  for (let i = 0; i < scores.length; i++) {
    detection.categories.push({
      score: scores[i],

@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
  }

  if (source.getLocationData()?.getRelativeKeypointsList().length) {
    detection.keypoints = [];
    for (const keypoint of
             source.getLocationData()!.getRelativeKeypointsList()) {
      detection.keypoints.push({
@ -62,7 +62,10 @@ jasmine_node_test(
mediapipe_ts_library(
    name = "mask",
    srcs = ["mask.ts"],
    deps = [":image"],
    deps = [
        ":image",
        "//mediapipe/web/graph_runner:platform_utils",
    ],
)

mediapipe_ts_library(
@ -60,6 +60,10 @@ class MPImageTestContext {

    this.webGLTexture = gl.createTexture()!;
    gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
    gl.texImage2D(
        gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap);
    gl.bindTexture(gl.TEXTURE_2D, null);
@ -187,10 +187,11 @@ export class MPImage {
        destinationContainer =
            assertNotNull(gl.createTexture(), 'Failed to create texture');
        gl.bindTexture(gl.TEXTURE_2D, destinationContainer);

        this.configureTextureParams();
        gl.texImage2D(
            gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA,
            gl.UNSIGNED_BYTE, null);
        gl.bindTexture(gl.TEXTURE_2D, null);

        shaderContext.bindFramebuffer(gl, destinationContainer);
        shaderContext.run(gl, /* flipVertically= */ false, () => {
@ -302,6 +303,20 @@ export class MPImage {
    return webGLTexture;
  }

  /** Sets texture params for the currently bound texture. */
  private configureTextureParams() {
    const gl = this.getGL();
    // `gl.LINEAR` might break rendering for some textures, but it allows us to
    // do smooth resizing. Ideally, this would be user-configurable, but for now
    // we hard-code the value here to `gl.LINEAR` (versus `gl.NEAREST` for
    // `MPMask` where we do not want to interpolate mask values, especially for
    // category masks).
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
  }

  /**
   * Binds the backing texture to the canvas. If the texture does not yet
   * exist, creates it first.
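
Note: the comment in `configureTextureParams()` motivates the LINEAR/NEAREST split between `MPImage` and `MPMask`. A hedged arithmetic illustration (not from the commit) of why linear filtering is unsafe for category masks:

// Linear filtering averages neighboring texels: fine for colors,
// invalid for the integer class IDs stored in a category mask.
const classA = 2;                                  // texel holding class 2
const classB = 7;                                  // neighboring texel, class 7
const linearSample = 0.5 * classA + 0.5 * classB;  // 4.5 -> no such class
const nearestSample = classA;                      // 2   -> a real class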
@ -318,16 +333,12 @@ export class MPImage {
          assertNotNull(gl.createTexture(), 'Failed to create texture');
      this.containers.push(webGLTexture);
      this.ownsWebGLTexture = true;

      gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
      this.configureTextureParams();
    } else {
      gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
    }

    gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
    // TODO: Ideally, we would only set these once per texture and
    // not once every frame.
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);

    return webGLTexture;
  }
@ -60,8 +60,11 @@ class MPMaskTestContext {
    }

    this.webGLTexture = gl.createTexture()!;

    gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
    gl.texImage2D(
        gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT,
        new Float32Array(pixels).map(v => v / 255));
@ -15,6 +15,7 @@
 */

import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {isIOS} from '../../../../web/graph_runner/platform_utils';

/** Number of instances a user can keep alive before we raise a warning. */
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;

@ -32,6 +33,8 @@ enum MPMaskType {
/** The supported mask formats. For internal usage. */
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;

/**
 * The wrapper class for MediaPipe segmentation masks.
 *

@ -56,6 +59,9 @@ export class MPMask {
   */
  private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;

  /** The format used to write pixel values from textures. */
  private static texImage2DFormat?: GLenum;

  /** @hideconstructor */
  constructor(
      private readonly containers: MPMaskContainer[],
@ -127,6 +133,29 @@ export class MPMask {
    return this.convertToWebGLTexture();
  }

  /**
   * Returns the texture format used for writing float textures on this
   * platform.
   */
  getTexImage2DFormat(): GLenum {
    const gl = this.getGL();
    if (!MPMask.texImage2DFormat) {
      // Note: This is the same check we use in
      // `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
      if (gl.getExtension('EXT_color_buffer_float') &&
          gl.getExtension('OES_texture_float_linear') &&
          gl.getExtension('EXT_float_blend')) {
        MPMask.texImage2DFormat = gl.R32F;
      } else if (gl.getExtension('EXT_color_buffer_half_float')) {
        MPMask.texImage2DFormat = gl.R16F;
      } else {
        throw new Error(
            'GPU does not fully support 4-channel float32 or float16 formats');
      }
    }
    return MPMask.texImage2DFormat;
  }

  private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
  private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
  private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -175,8 +204,10 @@ export class MPMask {
        destinationContainer =
            assertNotNull(gl.createTexture(), 'Failed to create texture');
        gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
        this.configureTextureParams();
        const format = this.getTexImage2DFormat();
        gl.texImage2D(
            gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
            gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
            gl.FLOAT, null);
        gl.bindTexture(gl.TEXTURE_2D, null);

@ -207,7 +238,7 @@ export class MPMask {
    if (!this.canvas) {
      throw new Error(
          'Conversion to different image formats require that a canvas ' +
          'is passed when iniitializing the image.');
          'is passed when initializing the image.');
    }
    if (!this.gl) {
      this.gl = assertNotNull(

@ -215,11 +246,6 @@ export class MPMask {
          'You cannot use a canvas that is already bound to a different ' +
              'type of rendering context.');
    }
    const ext = this.gl.getExtension('EXT_color_buffer_float');
    if (!ext) {
      // TODO: Ensure this works on iOS
      throw new Error('Missing required EXT_color_buffer_float extension');
    }
    return this.gl;
  }
@ -237,18 +263,34 @@ export class MPMask {
      if (uint8Array) {
        float32Array = new Float32Array(uint8Array).map(v => v / 255);
      } else {
        float32Array = new Float32Array(this.width * this.height);

        const gl = this.getGL();
        const shaderContext = this.getShaderContext();
        float32Array = new Float32Array(this.width * this.height);

        // Create texture if needed
        const webGlTexture = this.convertToWebGLTexture();

        // Create a framebuffer from the texture and read back pixels
        shaderContext.bindFramebuffer(gl, webGlTexture);
        gl.readPixels(
            0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
        shaderContext.unbindFramebuffer();

        if (isIOS()) {
          // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
          // (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
          // Uint16Array, however, and requires a manual bitwise conversion from
          // Uint16 to floating point numbers. This conversion is more expensive
          // that reading back a Float32Array from the RGBA image and dropping
          // the superfluous data, so we do this instead.
          const outputArray = new Float32Array(this.width * this.height * 4);
          gl.readPixels(
              0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
          for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
            float32Array[i] = outputArray[j];
          }
        } else {
          gl.readPixels(
              0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
        }
      }
      this.containers.push(float32Array);
    }
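
Note: the `isIOS()` branch above widens the read to RGBA because WebKit rejects single-channel float reads. A self-contained sketch of the same pattern (hypothetical helper, not part of the commit; assumes a WebGL2 context whose bound framebuffer holds one float per pixel in the red channel):

function readSingleChannelFloat(
    gl: WebGL2RenderingContext, width: number, height: number,
    useRgbaFallback: boolean): Float32Array {
  const out = new Float32Array(width * height);
  if (useRgbaFallback) {
    // Read all four channels, then keep only the red component.
    const rgba = new Float32Array(width * height * 4);
    gl.readPixels(0, 0, width, height, gl.RGBA, gl.FLOAT, rgba);
    for (let i = 0, j = 0; i < out.length; ++i, j += 4) {
      out[i] = rgba[j];
    }
  } else {
    gl.readPixels(0, 0, width, height, gl.RED, gl.FLOAT, out);
  }
  return out;
}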
@ -273,9 +315,9 @@ export class MPMask {
      webGLTexture = this.bindTexture();

      const data = this.convertToFloat32Array();
      // TODO: Add support for R16F to support iOS
      const format = this.getTexImage2DFormat();
      gl.texImage2D(
          gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
          gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
          gl.FLOAT, data);
      this.unbindTexture();
    }

@ -283,6 +325,19 @@ export class MPMask {
    return webGLTexture;
  }

  /** Sets texture params for the currently bound texture. */
  private configureTextureParams() {
    const gl = this.getGL();
    // `gl.NEAREST` ensures that we do not get interpolated values for
    // masks. In some cases, the user might want interpolation (e.g. for
    // confidence masks), so we might want to make this user-configurable.
    // Note that `MPImage` uses `gl.LINEAR`.
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
  }

  /**
   * Binds the backing texture to the canvas. If the texture does not yet
   * exist, creates it first.
@ -299,15 +354,12 @@ export class MPMask {
          assertNotNull(gl.createTexture(), 'Failed to create texture');
      this.containers.push(webGLTexture);
      this.ownsWebGLTexture = true;
    }

    gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
    // TODO: Ideally, we would only set these once per texture and
    // not once every frame.
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);

      gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
      this.configureTextureParams();
    } else {
      gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
    }

    return webGLTexture;
  }
@ -191,7 +191,8 @@ describe('FaceDetector', () => {
        categoryName: '',
        displayName: '',
      }],
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
      keypoints: []
    });
  });
});
@ -171,7 +171,8 @@ export class FaceStylizer extends VisionTaskRunner {
  /**
   * Performs face stylization on the provided single image and returns the
   * result. This method creates a copy of the resulting image and should not be
   * used in high-throughput applictions. Only use this method when the
   * used in high-throughput applications. Only use this method when the
   * FaceStylizer is created with the image running mode.
   *
   * @param image An image to process.

@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
  /**
   * Performs face stylization on the provided single image and returns the
   * result. This method creates a copy of the resulting image and should not be
   * used in high-throughput applictions. Only use this method when the
   * used in high-throughput applications. Only use this method when the
   * FaceStylizer is created with the image running mode.
   *
   * The 'imageProcessingOptions' parameter can be used to specify one or all

@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
  /**
   * Performs face stylization on the provided video frame. This method creates
   * a copy of the resulting image and should not be used in high-throughput
   * applictions. Only use this method when the FaceStylizer is created with the
   * applications. Only use this method when the FaceStylizer is created with the
   * video running mode.
   *
   * The input frame can be of any size. It's required to provide the video
@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGE_SEGMENTER_GRAPH =
    'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =

@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner {
  private categoryMask?: MPMask;
  private confidenceMasks?: MPMask[];
  private qualityScores?: number[];
  private labels: string[] = [];
  private userCallback?: ImageSegmenterCallback;
  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;

@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
  /**
   * Performs image segmentation on the provided single image and returns the
   * segmentation result. This method creates a copy of the resulting masks and
   * should not be used in high-throughput applictions. Only use this method
   * should not be used in high-throughput applications. Only use this method
   * when the ImageSegmenter is created with running mode `image`.
   *
   * @param image An image to process.

@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
  /**
   * Performs image segmentation on the provided single image and returns the
   * segmentation result. This method creates a copy of the resulting masks and
   * should not be used in high-v applictions. Only use this method when
   * should not be used in high-v applications. Only use this method when
   * the ImageSegmenter is created with running mode `image`.
   *
   * @param image An image to process.

@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
  /**
   * Performs image segmentation on the provided video frame and returns the
   * segmentation result. This method creates a copy of the resulting masks and
   * should not be used in high-v applictions. Only use this method when
   * should not be used in high-v applications. Only use this method when
   * the ImageSegmenter is created with running mode `video`.
   *
   * @param videoFrame A video frame to process.

@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
  private reset(): void {
    this.categoryMask = undefined;
    this.confidenceMasks = undefined;
    this.qualityScores = undefined;
  }

  private processResults(): ImageSegmenterResult|void {
    try {
      const result =
          new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
      const result = new ImageSegmenterResult(
          this.confidenceMasks, this.categoryMask, this.qualityScores);
      if (this.userCallback) {
        this.userCallback(result);
      } else {

@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
          });
    }

    graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
    segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

    this.graphRunner.attachFloatVectorListener(
        QUALITY_SCORES_STREAM, (scores, timestamp) => {
          this.qualityScores = scores;
          this.setLatestOutputTimestamp(timestamp);
        });
    this.graphRunner.attachEmptyPacketListener(
        QUALITY_SCORES_STREAM, timestamp => {
          this.categoryMask = undefined;
          this.setLatestOutputTimestamp(timestamp);
        });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
@ -30,7 +30,13 @@ export class ImageSegmenterResult {
       * `WebGLTexture`-backed `MPImage` where each pixel represents the class
       * which the pixel in the original image was predicted to belong to.
       */
      readonly categoryMask?: MPMask) {}
      readonly categoryMask?: MPMask,
      /**
       * The quality scores of the result masks, in the range of [0, 1].
       * Defaults to `1` if the model doesn't output quality scores. Each
       * element corresponds to the score of the category in the model outputs.
       */
      readonly qualityScores?: number[]) {}

  /** Frees the resources held by the category and confidence masks. */
  close(): void {
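
Note: with `qualityScores` added to the result, callback consumers can pair each score with its category. A minimal usage sketch (hedged; `imageSegmenter` and `image` are placeholders for an initialized segmenter and an input element):

imageSegmenter.segment(image, result => {
  // `qualityScores` is optional: models without a quality output leave it
  // unset, and each element maps to the category at the same index.
  result.qualityScores?.forEach((score, i) => {
    console.log(`category ${i}: mask quality ${score.toFixed(2)}`);
  });
});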
@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
      ((images: WasmImage, timestamp: number) => void)|undefined;
  confidenceMasksListener:
      ((images: WasmImage[], timestamp: number) => void)|undefined;
  qualityScoresListener:
      ((data: number[], timestamp: number) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);

@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
              expect(stream).toEqual('confidence_masks');
              this.confidenceMasksListener = listener;
            });
    this.attachListenerSpies[2] =
        spyOn(this.graphRunner, 'attachFloatVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('quality_scores');
              this.qualityScoresListener = listener;
            });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });

@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
  it('invokes listener after masks are available', async () => {
    const categoryMask = new Uint8Array([1]);
    const confidenceMask = new Float32Array([0.0]);
    const qualityScores = [1.0];
    let listenerCalled = false;

    await imageSegmenter.setOptions(

@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
          ],
          1337);
      expect(listenerCalled).toBeFalse();
      imageSegmenter.qualityScoresListener!(qualityScores, 1337);
      expect(listenerCalled).toBeFalse();
    });

    return new Promise<void>(resolve => {
      imageSegmenter.segment({} as HTMLImageElement, () => {
      imageSegmenter.segment({} as HTMLImageElement, result => {
        listenerCalled = true;
        expect(result.categoryMask).toBeInstanceOf(MPMask);
        expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
        expect(result.qualityScores).toEqual(qualityScores);
        resolve();
      });
    });
@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGEA_SEGMENTER_GRAPH =
    'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;

@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
export class InteractiveSegmenter extends VisionTaskRunner {
  private categoryMask?: MPMask;
  private confidenceMasks?: MPMask[];
  private qualityScores?: number[];
  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
  private userCallback?: InteractiveSegmenterCallback;

@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
  private reset(): void {
    this.confidenceMasks = undefined;
    this.categoryMask = undefined;
    this.qualityScores = undefined;
  }

  private processResults(): InteractiveSegmenterResult|void {
    try {
      const result = new InteractiveSegmenterResult(
          this.confidenceMasks, this.categoryMask);
          this.confidenceMasks, this.categoryMask, this.qualityScores);
      if (this.userCallback) {
        this.userCallback(result);
      } else {

@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
          });
    }

    graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
    segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

    this.graphRunner.attachFloatVectorListener(
        QUALITY_SCORES_STREAM, (scores, timestamp) => {
          this.qualityScores = scores;
          this.setLatestOutputTimestamp(timestamp);
        });
    this.graphRunner.attachEmptyPacketListener(
        QUALITY_SCORES_STREAM, timestamp => {
          this.categoryMask = undefined;
          this.setLatestOutputTimestamp(timestamp);
        });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
       * `WebGLTexture`-backed `MPImage` where each pixel represents the class
       * which the pixel in the original image was predicted to belong to.
       */
      readonly categoryMask?: MPMask) {}
      readonly categoryMask?: MPMask,
      /**
       * The quality scores of the result masks, in the range of [0, 1].
       * Defaults to `1` if the model doesn't output quality scores. Each
       * element corresponds to the score of the category in the model outputs.
       */
      readonly qualityScores?: number[]) {}

  /** Frees the resources held by the category and confidence masks. */
  close(): void {
@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
      ((images: WasmImage, timestamp: number) => void)|undefined;
  confidenceMasksListener:
      ((images: WasmImage[], timestamp: number) => void)|undefined;
  qualityScoresListener:
      ((data: number[], timestamp: number) => void)|undefined;
  lastRoi?: RenderDataProto;

  constructor() {

@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
              expect(stream).toEqual('confidence_masks');
              this.confidenceMasksListener = listener;
            });
    this.attachListenerSpies[2] =
        spyOn(this.graphRunner, 'attachFloatVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('quality_scores');
              this.qualityScoresListener = listener;
            });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
    });
  });

  it('invokes listener after masks are avaiblae', async () => {
  it('invokes listener after masks are available', async () => {
    const categoryMask = new Uint8Array([1]);
    const confidenceMask = new Float32Array([0.0]);
    const qualityScores = [1.0];
    let listenerCalled = false;

    await interactiveSegmenter.setOptions(

@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
          ],
          1337);
      expect(listenerCalled).toBeFalse();
      interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
      expect(listenerCalled).toBeFalse();
    });

    return new Promise<void>(resolve => {
      interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
      interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
        listenerCalled = true;
        expect(result.categoryMask).toBeInstanceOf(MPMask);
        expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
        expect(result.qualityScores).toEqual(qualityScores);
        resolve();
      });
    });
@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
        categoryName: '',
        displayName: '',
      }],
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
      keypoints: []
    });
  });
});
@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
  // it uses "CriOS".
  return userAgent.includes('Safari') && !userAgent.includes('Chrome');
}

/** Detect if code is running on iOS. */
export function isIOS() {
  // Source:
  // https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
  return [
    'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
    'iPod'
    // tslint:disable-next-line:deprecation
  ].includes(navigator.platform)
      // iPad on iOS 13 detection
      || (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
}
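
Note: `isIOS()` is what gates the RGBA float-readback path in mask.ts above. A hedged sketch of the gating pattern (the function below is hypothetical, not part of the commit):

import {isIOS} from '../../../../web/graph_runner/platform_utils';

// Pick the pixel-read format per platform: iOS WebKit rejects
// single-channel float reads, so widen the read to RGBA there.
function pickReadFormat(gl: WebGL2RenderingContext): GLenum {
  return isIOS() ? gl.RGBA : gl.RED;
}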
setup.py

@ -357,7 +357,10 @@ class BuildExtension(build_ext.build_ext):
      for ext in self.extensions:
        target_name = self.get_ext_fullpath(ext.name)
        # Build x86
        self._build_binary(ext)
        self._build_binary(
            ext,
            ['--cpu=darwin', '--ios_multi_cpus=i386,x86_64,armv7,arm64'],
        )
        x86_name = self.get_ext_fullpath(ext.name)
        # Build Arm64
        ext.name = ext.name + '.arm64'
third_party/flatbuffers/BUILD.bazel

@ -42,16 +42,15 @@ filegroup(
        "include/flatbuffers/allocator.h",
        "include/flatbuffers/array.h",
        "include/flatbuffers/base.h",
        "include/flatbuffers/bfbs_generator.h",
        "include/flatbuffers/buffer.h",
        "include/flatbuffers/buffer_ref.h",
        "include/flatbuffers/code_generator.h",
        "include/flatbuffers/code_generators.h",
        "include/flatbuffers/default_allocator.h",
        "include/flatbuffers/detached_buffer.h",
        "include/flatbuffers/file_manager.h",
        "include/flatbuffers/flatbuffer_builder.h",
        "include/flatbuffers/flatbuffers.h",
        "include/flatbuffers/flatc.h",
        "include/flatbuffers/flex_flat_util.h",
        "include/flatbuffers/flexbuffers.h",
        "include/flatbuffers/grpc.h",
third_party/flatbuffers/workspace.bzl

@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
    third_party_http_archive(
        name = "flatbuffers",
        strip_prefix = "flatbuffers-23.1.21",
        sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
        strip_prefix = "flatbuffers-23.5.8",
        sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
            "https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
            "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
            "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
        ],
        build_file = "//third_party/flatbuffers:BUILD.bazel",
        delete = ["build_defs.bzl", "BUILD.bazel"],