Pulled changes from master

This commit is contained in:
Prianka Liz Kariat 2023-05-24 19:15:48 +05:30
commit 164eae8c16
78 changed files with 644 additions and 337 deletions

View File

@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
typedef ConcatenateVectorCalculator<std::string>
ConcatenateStringVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator);
// Example config:
// node {
// calculator: "ConcatenateTfLiteTensorVectorCalculator"

View File

@ -30,13 +30,15 @@ namespace mediapipe {
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp,
template <typename T>
void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()->Index(index).packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
MakePacket<std::vector<T>>(input).At(Timestamp(timestamp)));
}
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
template <typename T>
void AddInputVectors(const std::vector<std::vector<T>>& inputs,
int64_t timestamp, CalculatorRunner* runner) {
for (int i = 0; i < inputs.size(); ++i) {
AddInputVector(i, inputs[i], timestamp, runner);
@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size());
}
TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) {
CalculatorRunner runner("ConcatenateStringVectorCalculator",
/*options_string=*/"", /*num_inputs=*/3,
/*num_outputs=*/1, /*num_side_packets=*/0);
std::vector<std::vector<std::string>> inputs = {
{"a", "b"}, {"c"}, {"d", "e", "f"}};
AddInputVectors(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
}
typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
TestConcatenateUniqueIntPtrCalculator;
MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);

View File

@ -1099,6 +1099,7 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = True, # Defines TestServiceCalculator
)
cc_library(

View File

@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
return std::move(SetNoLogging());
}
StatusBuilder::operator Status() const& {
StatusBuilder::operator absl::Status() const& {
return StatusBuilder(*this).JoinMessageToStatus();
}
StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
absl::Status StatusBuilder::JoinMessageToStatus() {
if (!impl_) {

View File

@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
return std::move(*this << msg);
}
operator Status() const&;
operator Status() &&;
operator absl::Status() const&;
operator absl::Status() &&;
absl::Status JoinMessageToStatus();

View File

@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
lhs op## = rhs; \
return lhs; \
}
STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
#undef STRONG_INT_VS_STRONG_INT_BINARY_OP
// Define operators that take one StrongInt and one native integer argument.
@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
rhs op## = lhs; \
return rhs; \
}
STRONG_INT_VS_NUMERIC_BINARY_OP(*);
NUMERIC_VS_STRONG_INT_BINARY_OP(*);
STRONG_INT_VS_NUMERIC_BINARY_OP(/);
STRONG_INT_VS_NUMERIC_BINARY_OP(%);
STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(*)
NUMERIC_VS_STRONG_INT_BINARY_OP(*)
STRONG_INT_VS_NUMERIC_BINARY_OP(/)
STRONG_INT_VS_NUMERIC_BINARY_OP(%)
STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators)
#undef STRONG_INT_VS_NUMERIC_BINARY_OP
#undef NUMERIC_VS_STRONG_INT_BINARY_OP
@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
StrongInt<TagType, ValueType, ValidatorType> rhs) { \
return lhs.value() op rhs.value(); \
}
STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators)
#undef STRONG_INT_COMPARISON_OP
} // namespace intops

View File

@ -44,7 +44,6 @@ class GraphServiceBase {
constexpr GraphServiceBase(const char* key) : key(key) {}
virtual ~GraphServiceBase() = default;
inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
return DefaultInitializationUnsupported();
}
@ -52,14 +51,32 @@ class GraphServiceBase {
const char* key;
protected:
// `GraphService<T>` objects, deriving `GraphServiceBase` are designed to be
// global constants and not ever deleted through `GraphServiceBase`. Hence,
// protected and non-virtual destructor which helps to make `GraphService<T>`
// trivially destructible and properly defined as global constants.
//
// A class with any virtual functions should have a destructor that is either
// public and virtual or else protected and non-virtual.
// https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual
~GraphServiceBase() = default;
absl::Status DefaultInitializationUnsupported() const {
return absl::UnimplementedError(absl::StrCat(
"Graph service '", key, "' does not support default initialization"));
}
};
// A global constant to refer a service:
// - Requesting `CalculatorContract::UseService` from calculator
// - Accessing `Calculator/SubgraphContext::Service`from calculator/subgraph
// - Setting before graph initialization `CalculatorGraph::SetServiceObject`
//
// NOTE: In headers, define your graph service reference safely as following:
// `inline constexpr GraphService<YourService> kYourService("YourService");`
//
template <typename T>
class GraphService : public GraphServiceBase {
class GraphService final : public GraphServiceBase {
public:
using type = T;
using packet_type = std::shared_ptr<T>;
@ -68,7 +85,7 @@ class GraphService : public GraphServiceBase {
kDisallowDefaultInitialization)
: GraphServiceBase(my_key), default_init_(default_init) {}
absl::StatusOr<Packet> CreateDefaultObject() const override {
absl::StatusOr<Packet> CreateDefaultObject() const final {
if (default_init_ != kAllowDefaultInitialization) {
return DefaultInitializationUnsupported();
}

View File

@ -7,7 +7,7 @@
namespace mediapipe {
namespace {
const GraphService<int> kIntService("mediapipe::IntService");
constexpr GraphService<int> kIntService("mediapipe::IntService");
} // namespace
TEST(GraphServiceManager, SetGetServiceObject) {

View File

@ -14,6 +14,8 @@
#include "mediapipe/framework/graph_service.h"
#include <type_traits>
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
@ -159,7 +161,7 @@ TEST_F(GraphServiceTest, CreateDefault) {
struct TestServiceData {};
const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization(
"kTestServiceAllowDefaultInitialization",
GraphServiceBase::kAllowDefaultInitialization);
@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest,
HasSubstr("Service is unavailable.")));
}
const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization(
"kTestServiceDisallowDefaultInitialization",
GraphServiceBase::kDisallowDefaultInitialization);
constexpr GraphService<TestServiceData>
kTestServiceDisallowDefaultInitialization(
"kTestServiceDisallowDefaultInitialization",
GraphServiceBase::kDisallowDefaultInitialization);
static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>,
"GraphService is not trivially destructible");
class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator
: public CalculatorBase {

View File

@ -16,15 +16,6 @@
namespace mediapipe {
const GraphService<TestServiceObject> kTestService(
"test_service", GraphServiceBase::kDisallowDefaultInitialization);
const GraphService<int> kAnotherService(
"another_service", GraphServiceBase::kAllowDefaultInitialization);
const GraphService<NoDefaultConstructor> kNoDefaultService(
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
const GraphService<NeedsCreateMethod> kNeedsCreateService(
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));

View File

@ -22,14 +22,17 @@ namespace mediapipe {
using TestServiceObject = std::map<std::string, int>;
extern const GraphService<TestServiceObject> kTestService;
extern const GraphService<int> kAnotherService;
inline constexpr GraphService<TestServiceObject> kTestService(
"test_service", GraphServiceBase::kDisallowDefaultInitialization);
inline constexpr GraphService<int> kAnotherService(
"another_service", GraphServiceBase::kAllowDefaultInitialization);
class NoDefaultConstructor {
public:
NoDefaultConstructor() = delete;
};
extern const GraphService<NoDefaultConstructor> kNoDefaultService;
inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService(
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
class NeedsCreateMethod {
public:
@ -40,7 +43,8 @@ class NeedsCreateMethod {
private:
NeedsCreateMethod() = default;
};
extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService(
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
// Use a service.
class TestServiceCalculator : public CalculatorBase {

View File

@ -57,7 +57,7 @@ namespace mediapipe {
// have underflow/overflow etc. This type is used internally by Timestamp
// and TimestampDiff.
MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
mediapipe::intops::LogFatalOnError);
mediapipe::intops::LogFatalOnError)
class TimestampDiff;

View File

@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
#define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
mediapipe::PacketTypeIdToMediaPipeTypeData, \
mediapipe::tool::GetTypeHash< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
(mediapipe::MediaPipeTypeData{ \
mediapipe::tool::GetTypeHash< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
type_name, serialize_fn, deserialize_fn})); \
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \
(mediapipe::MediaPipeTypeData{ \
mediapipe::tool::GetTypeHash< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
type_name, serialize_fn, deserialize_fn}));
// End define MEDIAPIPE_REGISTER_TYPE.

View File

@ -38,7 +38,10 @@ cc_library(
srcs = ["gpu_service.cc"],
hdrs = ["gpu_service.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:graph_service"] + select({
deps = [
"//mediapipe/framework:graph_service",
"@com_google_absl//absl/base:core_headers",
] + select({
"//conditions:default": [
":gpu_shared_data_internal",
],
@ -292,6 +295,7 @@ cc_library(
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
] + select({
@ -630,6 +634,7 @@ cc_library(
"//mediapipe/framework:executor",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/base:core_headers",
] + select({
"//conditions:default": [],
"//mediapipe:apple": [

View File

@ -47,6 +47,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height,
auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height,
format, nullptr);
if (!buf->CreateInternal(data, alignment)) {
LOG(WARNING) << "Failed to create a GL texture";
return nullptr;
}
return buf;
@ -106,7 +107,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width,
bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
auto context = GlContext::GetCurrent();
if (!context) return false;
if (!context) {
LOG(WARNING) << "Cannot create a GL texture without a valid context";
return false;
}
producer_context_ = context; // Save creation GL context.

View File

@ -20,6 +20,7 @@
#include <memory>
#include <utility>
#include "absl/log/check.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
@ -72,8 +73,10 @@ class GpuBuffer {
// are not portable. Applications and calculators should normally obtain
// GpuBuffers in a portable way from the framework, e.g. using
// GpuBufferMultiPool.
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
: holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
CHECK(storage) << "Cannot construct GpuBuffer with null storage";
holder_ = std::make_shared<StorageHolder>(std::move(storage));
}
#if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
// This is used to support backward-compatible construction of GpuBuffer from

View File

@ -28,6 +28,12 @@ namespace mediapipe {
#define GL_HALF_FLOAT 0x140B
#endif // GL_HALF_FLOAT
#ifdef __EMSCRIPTEN__
#ifndef GL_HALF_FLOAT_OES
#define GL_HALF_FLOAT_OES 0x8D61
#endif // GL_HALF_FLOAT_OES
#endif // __EMSCRIPTEN__
#if !MEDIAPIPE_DISABLE_GPU
#ifdef GL_ES_VERSION_2_0
static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
@ -48,6 +54,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
case GL_RG8:
info->gl_internal_format = info->gl_format = GL_RG_EXT;
return;
#ifdef __EMSCRIPTEN__
case GL_RGBA16F:
info->gl_internal_format = GL_RGBA;
info->gl_type = GL_HALF_FLOAT_OES;
return;
#endif // __EMSCRIPTEN__
default:
return;
}

View File

@ -15,6 +15,7 @@
#ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_
#define MEDIAPIPE_GPU_GPU_SERVICE_H_
#include "absl/base/attributes.h"
#include "mediapipe/framework/graph_service.h"
#if !MEDIAPIPE_DISABLE_GPU
@ -29,7 +30,7 @@ class GpuResources {
};
#endif // MEDIAPIPE_DISABLE_GPU
extern const GraphService<GpuResources> kGpuService;
ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
} // namespace mediapipe

View File

@ -14,6 +14,7 @@
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#include "absl/base/attributes.h"
#include "mediapipe/framework/deps/no_destructor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/gpu/gl_context.h"
@ -116,7 +117,7 @@ GpuResources::~GpuResources() {
#endif // __APPLE__
}
extern const GraphService<GpuResources> kGpuService;
ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService;
absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
CHECK(node->Contract().ServiceRequests().contains(kGpuService.key));

View File

@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
"""Instantiates perceptual loss.
Args:
feature_weight: The weight coeffcients of multiple model extracted
feature_weight: The weight coefficients of multiple model extracted
features used for calculating the perceptual loss.
loss_weight: The weight coefficients between `style_loss` and
`content_loss`.

View File

@ -105,7 +105,7 @@ class FaceStylizer(object):
self._train_model(train_data=train_data, preprocessor=self._preprocessor)
def _create_model(self):
"""Creates the componenets of face stylizer."""
"""Creates the components of face stylizer."""
self._encoder = model_util.load_keras_model(
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
)
@ -138,7 +138,7 @@ class FaceStylizer(object):
"""
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
# TODO: Support processing mulitple input style images. The
# TODO: Support processing multiple input style images. The
# input style images are expected to have similar style.
# style_sample represents a tuple of (style_image, style_label).
style_sample = next(iter(train_dataset))

View File

@ -103,8 +103,8 @@ class ModelResourcesCache {
};
// Global service for mediapipe task model resources cache.
const mediapipe::GraphService<ModelResourcesCache> kModelResourcesCacheService(
"mediapipe::tasks::ModelResourcesCacheService");
inline constexpr mediapipe::GraphService<ModelResourcesCache>
kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService");
} // namespace core
} // namespace tasks

View File

@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
static constexpr Output<Image>::Multiple kConfidenceMaskOut{
"CONFIDENCE_MASK"};
static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
"QUALITY_SCORES"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
kConfidenceMaskOut, kCategoryMaskOut);
kConfidenceMaskOut, kCategoryMaskOut,
kQualityScoresOut);
static absl::Status UpdateContract(CalculatorContract* cc);
@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
absl::Status TensorsToSegmentationCalculator::Process(
mediapipe::CalculatorContext* cc) {
RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
<< "Expect a vector of single Tensor.";
const auto& input_tensor = kTensorsIn(cc).Get()[0];
const auto& input_tensors = kTensorsIn(cc).Get();
if (input_tensors.size() != 1 && input_tensors.size() != 2) {
return absl::InvalidArgumentError(
"Expect input tensor vector of size 1 or 2.");
}
const auto& input_tensor = *input_tensors.rbegin();
ASSIGN_OR_RETURN(const Shape input_shape,
GetImageLikeTensorShape(input_tensor));
// TODO: should use tensor signature to get the correct output
// tensor.
if (input_tensors.size() == 2) {
const auto& quality_tensor = input_tensors[0];
const float* quality_score_buffer =
quality_tensor.GetCpuReadView().buffer<float>();
const std::vector<float> quality_scores(
quality_score_buffer,
quality_score_buffer +
(quality_tensor.bytes() / quality_tensor.element_size()));
kQualityScoresOut(cc).Send(quality_scores);
} else {
// If the input_tensors don't contain quality scores, send the default
// quality scores as 1.
const std::vector<float> quality_scores(input_shape.channels, 1.0f);
kQualityScoresOut(cc).Send(quality_scores);
}
// Category mask does not require activation function.
if (options_.segmenter_options().output_type() ==
SegmenterOptions::CONFIDENCE_MASK &&

View File

@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kQualityScoresStreamName[] = "quality_scores";
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
graph.Out(kQualityScoresTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
category_mask =
status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
}
const std::vector<float>& quality_scores =
status_or_packets.value()[kQualityScoresStreamName]
.Get<std::vector<float>>();
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(
{{confidence_masks, category_mask}}, image_packet.Get<Image>(),
{{confidence_masks, category_mask, quality_scores}},
image_packet.Get<Image>(),
image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
};
}
@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
const std::vector<float>& quality_scores =
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
return {{confidence_masks, category_mask, quality_scores}};
}
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
const std::vector<float>& quality_scores =
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
return {{confidence_masks, category_mask, quality_scores}};
}
absl::Status ImageSegmenter::SegmentAsync(

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <optional>
#include <type_traits>
@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
// Struct holding the different output streams produced by the image segmenter
@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
std::optional<std::vector<Source<Image>>> confidence_masks;
std::optional<Source<Image>> category_mask;
// The same as the input image, mainly used for live stream mode.
std::optional<Source<std::vector<float>>> quality_scores;
Source<Image> image;
};
@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
"Segmentation tflite models are assumed to have a single subgraph.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
const auto* primary_subgraph = (*model.subgraphs())[0];
if (primary_subgraph->outputs()->size() != 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Segmentation tflite models are assumed to have a single output.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
ASSIGN_OR_RETURN(
*options->mutable_label_items(),
GetLabelItemsIfAny(*metadata_extractor,
*metadata_extractor->GetOutputTensorMetadata()->Get(0),
segmenter_option.display_names_locale()));
GetLabelItemsIfAny(
*metadata_extractor,
**metadata_extractor->GetOutputTensorMetadata()->crbegin(),
segmenter_option.display_names_locale()));
return absl::OkStatus();
}
@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
const tflite::Model& model = *model_resources.GetTfLiteModel();
const auto* primary_subgraph = (*model.subgraphs())[0];
const auto* output_tensor =
(*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
(*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
return output_tensor;
}
uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel();
const auto* primary_subgraph = (*model.subgraphs())[0];
return primary_subgraph->outputs()->size();
}
// Get the input tensor from the tflite model of given model resources.
absl::StatusOr<const tflite::Tensor*> GetInputTensor(
const core::ModelResources& model_resources) {
@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
}
}
if (output_streams.quality_scores) {
*output_streams.quality_scores >>
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
}
output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
}
}
auto quality_scores =
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
/*confidence_masks=*/std::nullopt,
/*category_mask=*/std::nullopt,
/*quality_scores=*/quality_scores,
/*image=*/image_and_tensors.image};
} else {
std::optional<std::vector<Source<Image>>> confidence_masks;
@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
if (output_category_mask_) {
category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
}
auto quality_scores =
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
/*confidence_masks=*/confidence_masks,
/*category_mask=*/category_mask,
/*quality_scores=*/quality_scores,
/*image=*/image_and_tensors.image};
}
}

View File

@ -33,6 +33,10 @@ struct ImageSegmenterResult {
// A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to.
std::optional<Image> category_mask;
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
// `1` if the model doesn't output quality scores. Each element corresponds to
// the score of the category in the model outputs.
std::vector<float> quality_scores;
};
} // namespace image_segmenter

View File

@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kQualityScoresStreamName[] = "quality_scores";
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr absl::string_view kImageTag{"IMAGE"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
graph.Out(kQualityScoresTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
const std::vector<float>& quality_scores =
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
return {{confidence_masks, category_mask, quality_scores}};
}
} // namespace interactive_segmenter

View File

@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
// Updates the graph to return `roi` stream which has same dimension as
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
graph[Output<Image>(kCategoryMaskTag)];
}
}
image_segmenter.Out(kQualityScoresTag) >>
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();

View File

@ -81,7 +81,7 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectionResult.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
],
)
@ -162,7 +162,7 @@ apple_static_xcframework(
":MPPImageClassifierResult.h",
":MPPObjectDetector.h",
":MPPObjectDetectorOptions.h",
":MPPObjectDetectionResult.h",
":MPPObjectDetectorResult.h",
],
deps = [
"//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",

View File

@ -16,17 +16,6 @@
NS_ASSUME_NONNULL_BEGIN
/**
* MediaPipe Tasks delegate.
*/
typedef NS_ENUM(NSUInteger, MPPDelegate) {
/** CPU. */
MPPDelegateCPU,
/** GPU. */
MPPDelegateGPU
} NS_SWIFT_NAME(Delegate);
/**
* Holds the base options that is used for creation of any type of task. It has fields with
* important information acceleration configuration, TFLite model source etc.
@ -37,12 +26,6 @@ NS_SWIFT_NAME(BaseOptions)
/** The path to the model asset to open and mmap in memory. */
@property(nonatomic, copy) NSString *modelAssetPath;
/**
* Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
* delegate CPU is used.
*/
@property(nonatomic) MPPDelegate delegate;
@end
NS_ASSUME_NONNULL_END

View File

@ -28,7 +28,6 @@
MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
baseOptions.modelAssetPath = self.modelAssetPath;
baseOptions.delegate = self.delegate;
return baseOptions;
}

View File

@ -33,20 +33,6 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
if (self.modelAssetPath) {
baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
}
switch (self.delegate) {
case MPPDelegateCPU: {
baseOptionsProto->mutable_acceleration()->mutable_tflite();
break;
}
case MPPDelegateGPU: {
// TODO: Provide an implementation for GPU Delegate.
[NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."];
break;
}
default:
break;
}
}
@end

View File

@ -28,9 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertNotEqual( \
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
NSNotFound)
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \

View File

@ -29,9 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertNotEqual( \
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
NSNotFound)
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
XCTAssertNotNil(textEmbedderResult); \

View File

@ -34,9 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertNotEqual( \
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
NSNotFound)
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \
@ -670,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times.
// An normal expectation will fail if expectation.fullfill() is not called
// An normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and

View File

@ -32,9 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertNotEqual( \
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
NSNotFound)
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \
XCTAssertEqual(category.index, expectedCategory.index, \
@ -70,7 +68,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
#pragma mark Results
+ (MPPObjectDetectionResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
+ (MPPObjectDetectorResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
(NSInteger)timestampInMilliseconds {
NSArray<MPPDetection *> *detections = @[
[[MPPDetection alloc] initWithCategories:@[
@ -95,8 +93,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
keypoints:nil],
];
return [[MPPObjectDetectionResult alloc] initWithDetections:detections
timestampInMilliseconds:timestampInMilliseconds];
return [[MPPObjectDetectorResult alloc] initWithDetections:detections
timestampInMilliseconds:timestampInMilliseconds];
}
- (void)assertDetections:(NSArray<MPPDetection *> *)detections
@ -112,25 +110,25 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
}
}
- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult
isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult
expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
XCTAssertNotNil(objectDetectionResult);
- (void)assertObjectDetectorResult:(MPPObjectDetectorResult *)objectDetectorResult
isEqualToExpectedResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult
expectedDetectionsCount:(NSInteger)expectedDetectionsCount {
XCTAssertNotNil(objectDetectorResult);
NSArray<MPPDetection *> *detectionsSubsetToCompare;
XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount);
if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) {
detectionsSubsetToCompare = [objectDetectionResult.detections
subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)];
XCTAssertEqual(objectDetectorResult.detections.count, expectedDetectionsCount);
if (objectDetectorResult.detections.count > expectedObjectDetectorResult.detections.count) {
detectionsSubsetToCompare = [objectDetectorResult.detections
subarrayWithRange:NSMakeRange(0, expectedObjectDetectorResult.detections.count)];
} else {
detectionsSubsetToCompare = objectDetectionResult.detections;
detectionsSubsetToCompare = objectDetectorResult.detections;
}
[self assertDetections:detectionsSubsetToCompare
isEqualToExpectedDetections:expectedObjectDetectionResult.detections];
isEqualToExpectedDetections:expectedObjectDetectorResult.detections];
XCTAssertEqual(objectDetectionResult.timestampInMilliseconds,
expectedObjectDetectionResult.timestampInMilliseconds);
XCTAssertEqual(objectDetectorResult.timestampInMilliseconds,
expectedObjectDetectorResult.timestampInMilliseconds);
}
#pragma mark File
@ -195,28 +193,27 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
- (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
usingObjectDetector:(MPPObjectDetector *)objectDetector
maxResults:(NSInteger)maxResults
equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult {
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage
error:nil];
equalsObjectDetectorResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult {
MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInImage:mppImage error:nil];
[self assertObjectDetectionResult:objectDetectionResult
isEqualToExpectedResult:expectedObjectDetectionResult
expectedDetectionsCount:maxResults > 0 ? maxResults
: objectDetectionResult.detections.count];
[self assertObjectDetectorResult:ObjectDetectorResult
isEqualToExpectedResult:expectedObjectDetectorResult
expectedDetectionsCount:maxResults > 0 ? maxResults
: ObjectDetectorResult.detections.count];
}
- (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
usingObjectDetector:(MPPObjectDetector *)objectDetector
maxResults:(NSInteger)maxResults
equalsObjectDetectionResult:
(MPPObjectDetectionResult *)expectedObjectDetectionResult {
equalsObjectDetectorResult:
(MPPObjectDetectorResult *)expectedObjectDetectorResult {
MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
[self assertResultsOfDetectInImage:mppImage
usingObjectDetector:objectDetector
maxResults:maxResults
equalsObjectDetectionResult:expectedObjectDetectionResult];
equalsObjectDetectorResult:expectedObjectDetectorResult];
}
#pragma mark General Tests
@ -266,10 +263,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
usingObjectDetector:objectDetector
maxResults:-1
equalsObjectDetectionResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
0]];
equalsObjectDetectorResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
0]];
}
- (void)testDetectWithOptionsSucceeds {
@ -280,10 +277,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
usingObjectDetector:objectDetector
maxResults:-1
equalsObjectDetectionResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
0]];
equalsObjectDetectorResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
0]];
}
- (void)testDetectWithMaxResultsSucceeds {
@ -297,10 +294,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
usingObjectDetector:objectDetector
maxResults:maxResults
equalsObjectDetectionResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
0]];
equalsObjectDetectorResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
0]];
}
- (void)testDetectWithScoreThresholdSucceeds {
@ -316,13 +313,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
boundingBox:CGRectMake(608, 161, 381, 439)
keypoints:nil],
];
MPPObjectDetectionResult *expectedObjectDetectionResult =
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
MPPObjectDetectorResult *expectedObjectDetectorResult =
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
usingObjectDetector:objectDetector
maxResults:-1
equalsObjectDetectionResult:expectedObjectDetectionResult];
equalsObjectDetectorResult:expectedObjectDetectorResult];
}
- (void)testDetectWithCategoryAllowlistSucceeds {
@ -359,13 +356,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
keypoints:nil],
];
MPPObjectDetectionResult *expectedDetectionResult =
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
MPPObjectDetectorResult *expectedDetectionResult =
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
usingObjectDetector:objectDetector
maxResults:-1
equalsObjectDetectionResult:expectedDetectionResult];
equalsObjectDetectorResult:expectedDetectionResult];
}
- (void)testDetectWithCategoryDenylistSucceeds {
@ -414,13 +411,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
keypoints:nil],
];
MPPObjectDetectionResult *expectedDetectionResult =
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
MPPObjectDetectorResult *expectedDetectionResult =
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
[self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage
usingObjectDetector:objectDetector
maxResults:-1
equalsObjectDetectionResult:expectedDetectionResult];
equalsObjectDetectorResult:expectedDetectionResult];
}
- (void)testDetectWithOrientationSucceeds {
@ -437,8 +434,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
keypoints:nil],
];
MPPObjectDetectionResult *expectedDetectionResult =
[[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0];
MPPObjectDetectorResult *expectedDetectionResult =
[[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0];
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage
orientation:UIImageOrientationRight];
@ -446,7 +443,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
[self assertResultsOfDetectInImage:image
usingObjectDetector:objectDetector
maxResults:1
equalsObjectDetectionResult:expectedDetectionResult];
equalsObjectDetectorResult:expectedDetectionResult];
}
#pragma mark Running Mode Tests
@ -613,15 +610,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage];
for (int i = 0; i < 3; i++) {
MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image
timestampInMilliseconds:i
error:nil];
MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInVideoFrame:image
timestampInMilliseconds:i
error:nil];
[self assertObjectDetectionResult:objectDetectionResult
isEqualToExpectedResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
expectedDetectionsCount:maxResults];
[self assertObjectDetectorResult:ObjectDetectorResult
isEqualToExpectedResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i]
expectedDetectionsCount:maxResults];
}
}
@ -676,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times.
// An normal expectation will fail if expectation.fullfill() is not called
// An normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if
// expectation is fullfilled <= `iterationCount` times.
// expectation is fulfilled <= `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
@ -714,16 +711,16 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
#pragma mark MPPObjectDetectorLiveStreamDelegate Methods
- (void)objectDetector:(MPPObjectDetector *)objectDetector
didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult
didFinishDetectionWithResult:(MPPObjectDetectorResult *)ObjectDetectorResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error {
NSInteger maxResults = 4;
[self assertObjectDetectionResult:objectDetectionResult
isEqualToExpectedResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
timestampInMilliseconds]
expectedDetectionsCount:maxResults];
[self assertObjectDetectorResult:ObjectDetectorResult
isEqualToExpectedResult:
[MPPObjectDetectorTests
expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:
timestampInMilliseconds]
expectedDetectionsCount:maxResults];
if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) {
[outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];

View File

@ -64,4 +64,3 @@ objc_library(
"//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
],
)

View File

@ -87,7 +87,7 @@ NS_SWIFT_NAME(GestureRecognizerOptions)
gestureRecognizerLiveStreamDelegate;
/** Sets the maximum number of hands can be detected by the GestureRecognizer. */
@property(nonatomic) NSInteger numberOfHands;
@property(nonatomic) NSInteger numberOfHands NS_SWIFT_NAME(numHands);
/** Sets minimum confidence score for the hand detection to be considered successful */
@property(nonatomic) float minHandDetectionConfidence;

View File

@ -31,7 +31,8 @@
MPPGestureRecognizerOptions *gestureRecognizerOptions = [super copyWithZone:zone];
gestureRecognizerOptions.runningMode = self.runningMode;
gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate = self.gestureRecognizerLiveStreamDelegate;
gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate =
self.gestureRecognizerLiveStreamDelegate;
gestureRecognizerOptions.numberOfHands = self.numberOfHands;
gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence;
gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence;

View File

@ -18,9 +18,9 @@
- (instancetype)initWithGestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
landmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_landmarks = landmarks;

View File

@ -22,7 +22,12 @@ objc_library(
hdrs = ["sources/MPPGestureRecognizerOptions+Helpers.h"],
deps = [
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers",
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",

View File

@ -18,7 +18,12 @@
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
namespace {
using CalculatorOptionsProto = mediapipe::CalculatorOptions;

View File

@ -17,7 +17,7 @@
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
static const int kMicrosecondsPerMillisecond = 1000;
namespace {
using ClassificationResultProto =
@ -29,19 +29,26 @@ using ::mediapipe::Packet;
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const Packet &)packet {
MPPClassificationResult *classificationResult;
// Even if packet does not validate as the expected type, you can safely access the timestamp.
NSInteger timestampInMilliSeconds =
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
return nil;
// MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s
// timestamp_ms(). It is 0 since the packet can't be validated as a `ClassificationResultProto`.
return [[MPPImageClassifierResult alloc]
initWithClassificationResult:[[MPPClassificationResult alloc] initWithClassifications:@[]
timestampInMilliseconds:0]
timestampInMilliseconds:timestampInMilliSeconds];
}
classificationResult = [MPPClassificationResult
MPPClassificationResult *classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
kMicrosecondsPerMillisecond)];
}
@end

View File

@ -17,9 +17,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPObjectDetectionResult",
srcs = ["sources/MPPObjectDetectionResult.m"],
hdrs = ["sources/MPPObjectDetectionResult.h"],
name = "MPPObjectDetectorResult",
srcs = ["sources/MPPObjectDetectorResult.m"],
hdrs = ["sources/MPPObjectDetectorResult.h"],
deps = [
"//mediapipe/tasks/ios/components/containers:MPPDetection",
"//mediapipe/tasks/ios/core:MPPTaskResult",
@ -31,7 +31,7 @@ objc_library(
srcs = ["sources/MPPObjectDetectorOptions.m"],
hdrs = ["sources/MPPObjectDetectorOptions.h"],
deps = [
":MPPObjectDetectionResult",
":MPPObjectDetectorResult",
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/vision/core:MPPRunningMode",
],
@ -47,8 +47,8 @@ objc_library(
"-x objective-c++",
],
deps = [
":MPPObjectDetectionResult",
":MPPObjectDetectorOptions",
":MPPObjectDetectorResult",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
@ -56,7 +56,7 @@ objc_library(
"//mediapipe/tasks/ios/vision/core:MPPImage",
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectionResultHelpers",
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers",
"//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorResultHelpers",
],
)

View File

@ -15,8 +15,8 @@
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN
@ -109,14 +109,13 @@ NS_SWIFT_NAME(ObjectDetector)
* @param error An optional error parameter populated when there is an error in performing object
* detection on the input image.
*
* @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
* @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
* has a bounding box that is expressed in the unrotated input frame of reference coordinates
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
* image data.
*/
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
error:(NSError **)error
NS_SWIFT_NAME(detect(image:));
- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
error:(NSError **)error NS_SWIFT_NAME(detect(image:));
/**
* Performs object detection on the provided video frame of type `MPPImage` using the whole
@ -139,14 +138,14 @@ NS_SWIFT_NAME(ObjectDetector)
* @param error An optional error parameter populated when there is an error in performing object
* detection on the input image.
*
* @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection
* @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection
* has a bounding box that is expressed in the unrotated input frame of reference coordinates
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
* image data.
*/
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
/**

View File

@ -19,8 +19,8 @@
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
namespace {
using ::mediapipe::NormalizedRect;
@ -118,9 +118,9 @@ static NSString *const kTaskName = @"objectDetector";
return;
}
MPPObjectDetectionResult *result = [MPPObjectDetectionResult
objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName
.cppString]];
MPPObjectDetectorResult *result = [MPPObjectDetectorResult
objectDetectorResultWithDetectionsPacket:statusOrPackets
.value()[kDetectionsStreamName.cppString]];
NSInteger timeStampInMilliseconds =
outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
@ -184,9 +184,9 @@ static NSString *const kTaskName = @"objectDetector";
return inputPacketMap;
}
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
regionOfInterest:(CGRect)roi
error:(NSError **)error {
- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi
imageSize:CGSizeMake(image.width, image.height)
@ -213,18 +213,18 @@ static NSString *const kTaskName = @"objectDetector";
return nil;
}
return [MPPObjectDetectionResult
objectDetectionResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]];
return [MPPObjectDetectorResult
objectDetectorResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]];
}
- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
return [self detectInImage:image regionOfInterest:CGRectZero error:error];
}
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampInMilliseconds:timestampInMilliseconds
error:error];
@ -239,9 +239,9 @@ static NSString *const kTaskName = @"objectDetector";
return nil;
}
return [MPPObjectDetectionResult
objectDetectionResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]];
return [MPPObjectDetectorResult
objectDetectorResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]];
}
- (BOOL)detectAsyncInImage:(MPPImage *)image

View File

@ -16,7 +16,7 @@
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN
@ -44,7 +44,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
*
* @param objectDetector The object detector which performed the object detection.
* This is useful to test equality when there are multiple instances of `MPPObjectDetector`.
* @param result The `MPPObjectDetectionResult` object that contains a list of detections, each
* @param result The `MPPObjectDetectorResult` object that contains a list of detections, each
* detection has a bounding box that is expressed in the unrotated input frame of reference
* coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
* underlying image data.
@ -54,7 +54,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
* detection on the input live stream image data.
*/
- (void)objectDetector:(MPPObjectDetector *)objectDetector
didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result
didFinishDetectionWithResult:(nullable MPPObjectDetectorResult *)result
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(nullable NSError *)error
NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:));

View File

@ -19,8 +19,8 @@
NS_ASSUME_NONNULL_BEGIN
/** Represents the detection results generated by `MPPObjectDetector`. */
NS_SWIFT_NAME(ObjectDetectionResult)
@interface MPPObjectDetectionResult : MPPTaskResult
NS_SWIFT_NAME(ObjectDetectorResult)
@interface MPPObjectDetectorResult : MPPTaskResult
/**
* The array of `MPPDetection` objects each of which has a bounding box that is expressed in the
@ -30,7 +30,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
@property(nonatomic, readonly) NSArray<MPPDetection *> *detections;
/**
* Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in
* Initializes a new `MPPObjectDetectorResult` with the given array of detections and timestamp (in
* milliseconds).
*
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
@ -38,7 +38,7 @@ NS_SWIFT_NAME(ObjectDetectionResult)
* x [0,image_height)`, which are the dimensions of the underlying image data.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
* @return An instance of `MPPObjectDetectorResult` initialized with the given array of detections
* and timestamp (in milliseconds).
*/
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections

View File

@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
@implementation MPPObjectDetectionResult
@implementation MPPObjectDetectorResult
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {

View File

@ -31,12 +31,12 @@ objc_library(
)
objc_library(
name = "MPPObjectDetectionResultHelpers",
srcs = ["sources/MPPObjectDetectionResult+Helpers.mm"],
hdrs = ["sources/MPPObjectDetectionResult+Helpers.h"],
name = "MPPObjectDetectorResultHelpers",
srcs = ["sources/MPPObjectDetectorResult+Helpers.mm"],
hdrs = ["sources/MPPObjectDetectorResult+Helpers.h"],
deps = [
"//mediapipe/framework:packet",
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectionResult",
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectorResult",
],
)

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
#include "mediapipe/framework/packet.h"
@ -20,17 +20,17 @@ NS_ASSUME_NONNULL_BEGIN
static const int kMicroSecondsPerMilliSecond = 1000;
@interface MPPObjectDetectionResult (Helpers)
@interface MPPObjectDetectorResult (Helpers)
/**
* Creates an `MPPObjectDetectionResult` from a MediaPipe packet containing a
* Creates an `MPPObjectDetectorResult` from a MediaPipe packet containing a
* `std::vector<DetectionProto>`.
*
* @param packet a MediaPipe packet wrapping a `std::vector<DetectionProto>`.
*
* @return An `MPPObjectDetectionResult` object that contains a list of detections.
* @return An `MPPObjectDetectorResult` object that contains a list of detections.
*/
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
(const mediapipe::Packet &)packet;
@end

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
@ -21,9 +21,9 @@ using DetectionProto = ::mediapipe::Detection;
using ::mediapipe::Packet;
} // namespace
@implementation MPPObjectDetectionResult (Helpers)
@implementation MPPObjectDetectorResult (Helpers)
+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
(const Packet &)packet {
if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
return nil;
@ -37,10 +37,10 @@ using ::mediapipe::Packet;
[detections addObject:[MPPDetection detectionWithProto:detectionProto]];
}
return [[MPPObjectDetectionResult alloc]
initWithDetections:detections
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
return
[[MPPObjectDetectorResult alloc] initWithDetections:detections
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}
@end

View File

@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
// For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or
// height by the image width or height, repectively.
// height by the image width or height, respectively.
// - then rotates this by denormalized rect by the provided rotation, and
// uses this for cropping,
// - then finally rotates this back.

View File

@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
segmenterOptions.outputCategoryMask()
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
: -1;
final int qualityScoresOutStreamIndex =
getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
// TODO: Consolidate OutputHandler and TaskRunner.
@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
new ArrayList<>(),
packets.get(imageOutStreamIndex).getTimestamp());
}
boolean copyImage = !segmenterOptions.resultListener().isPresent();
@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
categoryMask = Optional.of(builder.build());
}
float[] qualityScores =
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
for (float score : qualityScores) {
qualityScoresList.add(score);
}
return ImageSegmenterResult.create(
confidenceMasks,
categoryMask,
qualityScoresList,
BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
}
@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
public abstract Builder setOutputCategoryMask(boolean value);
/**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
* pipeline is done processing an image.
* /** Sets an optional {@link ResultListener} to receive the segmentation results when the
* graph pipeline is done processing an image.
*/
public abstract Builder setResultListener(
ResultListener<ImageSegmenterResult, MPImage> value);

View File

@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to.
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
* the category in the model outputs.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
public static ImageSegmenterResult create(
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
Optional<List<MPImage>> confidenceMasks,
Optional<MPImage> categoryMask,
List<Float> qualityScores,
long timestampMs) {
return new AutoValue_ImageSegmenterResult(
confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
confidenceMasks.map(Collections::unmodifiableList),
categoryMask,
Collections.unmodifiableList(qualityScores),
timestampMs);
}
public abstract Optional<List<MPImage>> confidenceMasks();
public abstract Optional<MPImage> categoryMask();
public abstract List<Float> qualityScores();
@Override
public abstract long timestampMs();
}

View File

@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
outputStreams.add("CATEGORY_MASK:category_mask");
}
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("QUALITY_SCORES:quality_scores");
final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("IMAGE:image_out");
// TODO: add test for stream indices.
final int imageOutStreamIndex = outputStreams.size() - 1;
@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
new ArrayList<>(),
packets.get(imageOutStreamIndex).getTimestamp());
}
// If resultListener is not provided, the resulted MPImage is deep copied from
@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
categoryMask = Optional.of(builder.build());
}
float[] qualityScores =
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
for (float score : qualityScores) {
qualityScoresList.add(score);
}
return ImageSegmenterResult.create(
confidenceMasks,
categoryMask,
qualityScoresList,
BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
}

View File

@ -201,6 +201,7 @@ py_test(
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
tags = ["not_run:arm"],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:rect",

View File

@ -27,13 +27,14 @@ export declare interface Detection {
boundingBox?: BoundingBox;
/**
* Optional list of keypoints associated with the detection. Keypoints
* represent interesting points related to the detection. For example, the
* keypoints represent the eye, ear and mouth from face detection model. Or
* in the template matching detection, e.g. KNIFT, they can represent the
* feature points for template matching.
* List of keypoints associated with the detection. Keypoints represent
* interesting points related to the detection. For example, the keypoints
* represent the eye, ear and mouth from face detection model. Or in the
* template matching detection, e.g. KNIFT, they can represent the feature
* points for template matching. Contains an empty list if no keypoints are
* detected.
*/
keypoints?: NormalizedKeypoint[];
keypoints: NormalizedKeypoint[];
}
/** Detection results of a model. */

View File

@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
categoryName: '',
displayName: '',
}],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
});
});
});

View File

@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
const labels = source.getLabelList();
const displayNames = source.getDisplayNameList();
const detection: Detection = {categories: []};
const detection: Detection = {categories: [], keypoints: []};
for (let i = 0; i < scores.length; i++) {
detection.categories.push({
score: scores[i],
@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
}
if (source.getLocationData()?.getRelativeKeypointsList().length) {
detection.keypoints = [];
for (const keypoint of
source.getLocationData()!.getRelativeKeypointsList()) {
detection.keypoints.push({

View File

@ -62,7 +62,10 @@ jasmine_node_test(
mediapipe_ts_library(
name = "mask",
srcs = ["mask.ts"],
deps = [":image"],
deps = [
":image",
"//mediapipe/web/graph_runner:platform_utils",
],
)
mediapipe_ts_library(

View File

@ -60,6 +60,10 @@ class MPImageTestContext {
this.webGLTexture = gl.createTexture()!;
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap);
gl.bindTexture(gl.TEXTURE_2D, null);

View File

@ -187,10 +187,11 @@ export class MPImage {
destinationContainer =
assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
this.configureTextureParams();
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA,
gl.UNSIGNED_BYTE, null);
gl.bindTexture(gl.TEXTURE_2D, null);
shaderContext.bindFramebuffer(gl, destinationContainer);
shaderContext.run(gl, /* flipVertically= */ false, () => {
@ -302,6 +303,20 @@ export class MPImage {
return webGLTexture;
}
/** Sets texture params for the currently bound texture. */
private configureTextureParams() {
const gl = this.getGL();
// `gl.LINEAR` might break rendering for some textures, but it allows us to
// do smooth resizing. Ideally, this would be user-configurable, but for now
// we hard-code the value here to `gl.LINEAR` (versus `gl.NEAREST` for
// `MPMask` where we do not want to interpolate mask values, especially for
// category masks).
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
}
/**
* Binds the backing texture to the canvas. If the texture does not yet
* exist, creates it first.
@ -318,16 +333,12 @@ export class MPImage {
assertNotNull(gl.createTexture(), 'Failed to create texture');
this.containers.push(webGLTexture);
this.ownsWebGLTexture = true;
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
this.configureTextureParams();
} else {
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
}
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
// TODO: Ideally, we would only set these once per texture and
// not once every frame.
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
return webGLTexture;
}

View File

@ -60,8 +60,11 @@ class MPMaskTestContext {
}
this.webGLTexture = gl.createTexture()!;
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT,
new Float32Array(pixels).map(v => v / 255));

View File

@ -15,6 +15,7 @@
*/
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {isIOS} from '../../../../web/graph_runner/platform_utils';
/** Number of instances a user can keep alive before we raise a warning. */
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@ -32,6 +33,8 @@ enum MPMaskType {
/** The supported mask formats. For internal usage. */
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
/**
* The wrapper class for MediaPipe segmentation masks.
*
@ -56,6 +59,9 @@ export class MPMask {
*/
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
/** The format used to write pixel values from textures. */
private static texImage2DFormat?: GLenum;
/** @hideconstructor */
constructor(
private readonly containers: MPMaskContainer[],
@ -127,6 +133,29 @@ export class MPMask {
return this.convertToWebGLTexture();
}
/**
* Returns the texture format used for writing float textures on this
* platform.
*/
getTexImage2DFormat(): GLenum {
const gl = this.getGL();
if (!MPMask.texImage2DFormat) {
// Note: This is the same check we use in
// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
if (gl.getExtension('EXT_color_buffer_float') &&
gl.getExtension('OES_texture_float_linear') &&
gl.getExtension('EXT_float_blend')) {
MPMask.texImage2DFormat = gl.R32F;
} else if (gl.getExtension('EXT_color_buffer_half_float')) {
MPMask.texImage2DFormat = gl.R16F;
} else {
throw new Error(
'GPU does not fully support 4-channel float32 or float16 formats');
}
}
return MPMask.texImage2DFormat;
}
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -175,8 +204,10 @@ export class MPMask {
destinationContainer =
assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
this.configureTextureParams();
const format = this.getTexImage2DFormat();
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, null);
gl.bindTexture(gl.TEXTURE_2D, null);
@ -207,7 +238,7 @@ export class MPMask {
if (!this.canvas) {
throw new Error(
'Conversion to different image formats require that a canvas ' +
'is passed when iniitializing the image.');
'is passed when initializing the image.');
}
if (!this.gl) {
this.gl = assertNotNull(
@ -215,11 +246,6 @@ export class MPMask {
'You cannot use a canvas that is already bound to a different ' +
'type of rendering context.');
}
const ext = this.gl.getExtension('EXT_color_buffer_float');
if (!ext) {
// TODO: Ensure this works on iOS
throw new Error('Missing required EXT_color_buffer_float extension');
}
return this.gl;
}
@ -237,18 +263,34 @@ export class MPMask {
if (uint8Array) {
float32Array = new Float32Array(uint8Array).map(v => v / 255);
} else {
float32Array = new Float32Array(this.width * this.height);
const gl = this.getGL();
const shaderContext = this.getShaderContext();
float32Array = new Float32Array(this.width * this.height);
// Create texture if needed
const webGlTexture = this.convertToWebGLTexture();
// Create a framebuffer from the texture and read back pixels
shaderContext.bindFramebuffer(gl, webGlTexture);
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
shaderContext.unbindFramebuffer();
if (isIOS()) {
// WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
// Uint16Array, however, and requires a manual bitwise conversion from
// Uint16 to floating point numbers. This conversion is more expensive
// that reading back a Float32Array from the RGBA image and dropping
// the superfluous data, so we do this instead.
const outputArray = new Float32Array(this.width * this.height * 4);
gl.readPixels(
0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
float32Array[i] = outputArray[j];
}
} else {
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
}
}
this.containers.push(float32Array);
}
@ -273,9 +315,9 @@ export class MPMask {
webGLTexture = this.bindTexture();
const data = this.convertToFloat32Array();
// TODO: Add support for R16F to support iOS
const format = this.getTexImage2DFormat();
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, data);
this.unbindTexture();
}
@ -283,6 +325,19 @@ export class MPMask {
return webGLTexture;
}
/** Sets texture params for the currently bound texture. */
private configureTextureParams() {
const gl = this.getGL();
// `gl.NEAREST` ensures that we do not get interpolated values for
// masks. In some cases, the user might want interpolation (e.g. for
// confidence masks), so we might want to make this user-configurable.
// Note that `MPImage` uses `gl.LINEAR`.
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
}
/**
* Binds the backing texture to the canvas. If the texture does not yet
* exist, creates it first.
@ -299,15 +354,12 @@ export class MPMask {
assertNotNull(gl.createTexture(), 'Failed to create texture');
this.containers.push(webGLTexture);
this.ownsWebGLTexture = true;
}
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
// TODO: Ideally, we would only set these once per texture and
// not once every frame.
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
this.configureTextureParams();
} else {
gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
}
return webGLTexture;
}

View File

@ -191,7 +191,8 @@ describe('FaceDetector', () => {
categoryName: '',
displayName: '',
}],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
});
});
});

View File

@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
/**
* Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the
* used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode.
*
* @param image An image to process.
@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
/**
* Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the
* used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode.
*
* The 'imageProcessingOptions' parameter can be used to specify one or all
@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
/**
* Performs face stylization on the provided video frame. This method creates
* a copy of the resulting image and should not be used in high-throughput
* applictions. Only use this method when the FaceStylizer is created with the
* applications. Only use this method when the FaceStylizer is created with the
* video running mode.
*
* The input frame can be of any size. It's required to provide the video

View File

@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private labels: string[] = [];
private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
* Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method
* should not be used in high-throughput applications. Only use this method
* when the ImageSegmenter is created with running mode `image`.
*
* @param image An image to process.
@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
* Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-v applictions. Only use this method when
* should not be used in high-v applications. Only use this method when
* the ImageSegmenter is created with running mode `image`.
*
* @param image An image to process.
@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
* Performs image segmentation on the provided video frame and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-v applictions. Only use this method when
* should not be used in high-v applications. Only use this method when
* the ImageSegmenter is created with running mode `video`.
*
* @param videoFrame A video frame to process.
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
private reset(): void {
this.categoryMask = undefined;
this.confidenceMasks = undefined;
this.qualityScores = undefined;
}
private processResults(): ImageSegmenterResult|void {
try {
const result =
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
const result = new ImageSegmenterResult(
this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) {
this.userCallback(result);
} else {
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
});
}
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
QUALITY_SCORES_STREAM, (scores, timestamp) => {
this.qualityScores = scores;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
QUALITY_SCORES_STREAM, timestamp => {
this.categoryMask = undefined;
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}

View File

@ -30,7 +30,13 @@ export class ImageSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to.
*/
readonly categoryMask?: MPMask) {}
readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {

View File

@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false;
await imageSegmenter.setOptions(
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
],
1337);
expect(listenerCalled).toBeFalse();
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
});
return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, () => {
imageSegmenter.segment({} as HTMLImageElement, result => {
listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve();
});
});

View File

@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
export class InteractiveSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback?: InteractiveSegmenterCallback;
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
private reset(): void {
this.confidenceMasks = undefined;
this.categoryMask = undefined;
this.qualityScores = undefined;
}
private processResults(): InteractiveSegmenterResult|void {
try {
const result = new InteractiveSegmenterResult(
this.confidenceMasks, this.categoryMask);
this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) {
this.userCallback(result);
} else {
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
});
}
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
QUALITY_SCORES_STREAM, (scores, timestamp) => {
this.qualityScores = scores;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
QUALITY_SCORES_STREAM, timestamp => {
this.categoryMask = undefined;
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}

View File

@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to.
*/
readonly categoryMask?: MPMask) {}
readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {

View File

@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto;
constructor() {
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
});
});
it('invokes listener after masks are avaiblae', async () => {
it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false;
await interactiveSegmenter.setOptions(
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
],
1337);
expect(listenerCalled).toBeFalse();
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
});
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve();
});
});

View File

@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
categoryName: '',
displayName: '',
}],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
});
});
});

View File

@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
// it uses "CriOS".
return userAgent.includes('Safari') && !userAgent.includes('Chrome');
}
/** Detect if code is running on iOS. */
export function isIOS() {
// Source:
// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
return [
'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
'iPod'
// tslint:disable-next-line:deprecation
].includes(navigator.platform)
// iPad on iOS 13 detection
|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
}

View File

@ -357,7 +357,10 @@ class BuildExtension(build_ext.build_ext):
for ext in self.extensions:
target_name = self.get_ext_fullpath(ext.name)
# Build x86
self._build_binary(ext)
self._build_binary(
ext,
['--cpu=darwin', '--ios_multi_cpus=i386,x86_64,armv7,arm64'],
)
x86_name = self.get_ext_fullpath(ext.name)
# Build Arm64
ext.name = ext.name + '.arm64'

View File

@ -42,16 +42,15 @@ filegroup(
"include/flatbuffers/allocator.h",
"include/flatbuffers/array.h",
"include/flatbuffers/base.h",
"include/flatbuffers/bfbs_generator.h",
"include/flatbuffers/buffer.h",
"include/flatbuffers/buffer_ref.h",
"include/flatbuffers/code_generator.h",
"include/flatbuffers/code_generators.h",
"include/flatbuffers/default_allocator.h",
"include/flatbuffers/detached_buffer.h",
"include/flatbuffers/file_manager.h",
"include/flatbuffers/flatbuffer_builder.h",
"include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flatc.h",
"include/flatbuffers/flex_flat_util.h",
"include/flatbuffers/flexbuffers.h",
"include/flatbuffers/grpc.h",

View File

@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "flatbuffers",
strip_prefix = "flatbuffers-23.1.21",
sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
strip_prefix = "flatbuffers-23.5.8",
sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
"https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
],
build_file = "//third_party/flatbuffers:BUILD.bazel",
delete = ["build_defs.bzl", "BUILD.bazel"],