Merge branch 'google:master' into pose-landmarker-python

Kinar R, 2023-04-19 10:21:23 +05:30, committed by GitHub
commit 39742b6641
98 changed files with 5096 additions and 974 deletions

View File

@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
} else if (packet_options.has_string_value()) {
packet.Set<std::string>();
} else if (packet_options.has_uint64_value()) {
packet.Set<uint64>();
packet.Set<uint64_t>();
} else if (packet_options.has_classification_list_value()) {
packet.Set<ClassificationList>();
} else if (packet_options.has_landmark_list_value()) {
@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
} else if (packet_options.has_string_value()) {
packet.Set(MakePacket<std::string>(packet_options.string_value()));
} else if (packet_options.has_uint64_value()) {
packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
} else if (packet_options.has_classification_list_value()) {
packet.Set(MakePacket<ClassificationList>(
packet_options.classification_list_value()));
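Note: these hunks (and several files below) belong to a repo-wide migration from the legacy `int64`/`uint64` aliases to the standard `<cstdint>` fixed-width types. A minimal standalone sketch of the pattern, illustrative only and not part of this diff:
#include <cstdint>  // int64_t / uint64_t are exactly 64 bits on every platform
#include <iostream>
int main() {
  // The legacy `int64`/`uint64` names are project-specific typedefs; the
  // standard types below express the same intent without the extra dependency.
  int64_t timestamp = 42;
  uint64_t packet_value = 1000000;
  std::cout << timestamp << " " << packet_value << "\n";
  return 0;
}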

View File

@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
}
// Use this when ALLOW/DISALLOW input is provided as a side packet.
void RunTimeStep(int64 timestamp, bool stream_payload) {
void RunTimeStep(int64_t timestamp, bool stream_payload) {
runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
}
// Use this when ALLOW/DISALLOW input is provided as an input stream.
void RunTimeStep(int64 timestamp, const std::string& control_tag,
void RunTimeStep(int64_t timestamp, const std::string& control_tag,
bool control) {
runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(true).At(Timestamp(timestamp)));
@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
}
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
}
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
output_stream: "test_output"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
)");
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
)");
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
)");
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
)");
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
output_stream: "test_output"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "ALLOW", false);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "ALLOW", true);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "ALLOW", false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
output_stream: "test_output"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "DISALLOW", false);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "DISALLOW", true);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "DISALLOW", false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", false);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "ALLOW", true);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "ALLOW", true);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "ALLOW", false);
const std::vector<Packet>& output =
@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "DISALLOW", false);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "DISALLOW", false);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "DISALLOW", true);
const std::vector<Packet>& output =
@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", false);
const std::vector<Packet>& output =
@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true);
const std::vector<Packet>& output =

View File

@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
REGISTER_CALCULATOR(StringToUintCalculator);
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
REGISTER_CALCULATOR(StringToInt32Calculator);
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
REGISTER_CALCULATOR(StringToUint32Calculator);
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
REGISTER_CALCULATOR(StringToInt64Calculator);
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
REGISTER_CALCULATOR(StringToUint64Calculator);
} // namespace mediapipe

View File

@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
frame_ptr->Height(), frame_ptr->WidthStep(),
const_cast<uint8_t*>(frame_ptr->PixelData()),
[](uint8* data){});
[](uint8_t* data){});
ASSIGN_OR_RETURN(auto result,
runner->Run(image_frame, matrix, size, border_mode));
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));

View File

@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
// Record the most recent first kept timestamp on any stream.
for (const auto& stream : input_stream_managers_) {
int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
? target_queue_size_
: trigger_queue_size_ - 1;
int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
? target_queue_size_
: trigger_queue_size_ - 1;
if (stream->QueueSize() > queue_size) {
kept_timestamp_ = std::max(
kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
}
private:
int32 trigger_queue_size_;
int32 target_queue_size_;
int32_t trigger_queue_size_;
int32_t target_queue_size_;
bool fixed_min_size_;
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and
// the corresponding call to FillInputSet has not yet completed.

View File

@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
// TODO: Investigate this option in more detail, esp. on Safari.
attrs.preserveDrawingBuffer = 0;
// Since the Emscripten canvas target finding function is visible from here,
// we hijack findCanvasEventTarget directly for enforcing old Module.canvas
// behavior if the user desires, falling back to the new DOM element CSS
// selector behavior next if that is specified, and finally just allowing the
// lookup to proceed on a null target.
// TODO: Ensure this works with all options (in particular,
// multithreading options, like the special-case combination of USE_PTHREADS
// and OFFSCREEN_FRAMEBUFFER)
// clang-format off
EM_ASM(
let init_once = true;
if (init_once) {
const cachedFindCanvasEventTarget = findCanvasEventTarget;
if (typeof cachedFindCanvasEventTarget !== 'function') {
if (typeof console !== 'undefined') {
console.error('Expected Emscripten global function '
+ '"findCanvasEventTarget" not found. WebGL context creation '
+ 'may fail.');
}
return;
}
findCanvasEventTarget = function(target) {
if (target == 0) {
if (Module && Module.canvas) {
return Module.canvas;
} else if (Module && Module.canvasCssSelector) {
return cachedFindCanvasEventTarget(Module.canvasCssSelector);
}
if (typeof console !== 'undefined') {
console.warn('Module properties canvas and canvasCssSelector not ' +
'found during WebGL context creation.');
}
}
// We still go through with the find attempt, although for most use
// cases it will not succeed, just in case the user does want to fall-
// back.
return cachedFindCanvasEventTarget(target);
}; // NOLINT: Necessary semicolon.
init_once = false;
}
);
// clang-format on
// Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
// looks for our #canvas target in Module.canvas, where we expect it to be.
// -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
// event target behavior, but it was never supposed to be tapping into our
// canvas anyways. See b/278155946 for more background.
EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
emscripten_webgl_create_context(nullptr, &attrs);
emscripten_webgl_create_context("#canvas", &attrs);
// Check for failure
if (context_handle <= 0) {
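Note: a minimal sketch of the new context-creation path described in the comment above, assuming an Emscripten build; the helper function name is illustrative and not part of this diff:
#include <emscripten.h>
#include <emscripten/html5.h>
// Register Module.canvas under the "#canvas" selector, then create the WebGL
// context against that selector, mirroring the quick patch above.
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE CreateContextOnModuleCanvas() {
  EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
  EmscriptenWebGLContextAttributes attrs;
  emscripten_webgl_init_context_attributes(&attrs);
  return emscripten_webgl_create_context("#canvas", &attrs);
}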

View File

@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
int actual_ws = image_frame.WidthStep();
int alignment = 0;
std::unique_ptr<ImageFrame> temp;
const uint8* data = image_frame.PixelData();
const uint8_t* data = image_frame.PixelData();
// Let's see if the pixel data is tightly aligned to one of the alignments
// supported by OpenGL, preferring 4 if possible since it's the default.

View File

@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
GpuBufferFormat format) {
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
auto y_data = std::make_unique<uint8[]>(y_stride * height);
auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
switch (fourcc) {
case libyuv::FOURCC_NV12:
case libyuv::FOURCC_NV21: {
@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
int uv_width = 2 * std::ceil(0.5f * width);
int uv_height = std::ceil(0.5f * height);
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
yuv_image_ = std::make_shared<YUVImage>(
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
nullptr, 0, width, height);
@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
int uv_width = std::ceil(0.5f * width);
int uv_height = std::ceil(0.5f * height);
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
yuv_image_ = std::make_shared<YUVImage>(
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
std::move(v_data), uv_stride, width, height);

View File

@ -16,6 +16,7 @@ import csv
import filecmp
import os
import tempfile
import unittest
from unittest import mock as unittest_mock
import tensorflow as tf
@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils
@unittest.skip('b/275624089')
class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"],
tags = ["requires-net:external"],
deps = [
":dataset",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
":object_detector_import",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.model_maker.python.vision import object_detector
from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder(
self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir
)
# Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self):
hparams = hyperparameters.HParams(
hparams = object_detector.HParams(
epochs=1,
batch_size=2,
learning_rate=0.9,
shuffle=False,
export_dir=self.create_tempdir(),
)
options = object_detector_options.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
options = object_detector.ObjectDetectorOptions(
supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
)
# Test `create`
model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams(
qat_hparams = object_detector.QATHParams(
learning_rate=0.9,
batch_size=2,
epochs=1,

View File

@ -24,8 +24,8 @@ namespace mediapipe {
void FrameAnnotationTracker::AddDetectionResult(
const FrameAnnotation& frame_annotation) {
const int64 time_us =
static_cast<int64>(std::round(frame_annotation.timestamp()));
const int64_t time_us =
static_cast<int64_t>(std::round(frame_annotation.timestamp()));
for (const auto& object_annotation : frame_annotation.annotations()) {
detected_objects_[time_us + object_annotation.object_id()] =
object_annotation;
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
absl::flat_hash_set<int>* cancel_object_ids) {
CHECK(cancel_object_ids != nullptr);
FrameAnnotation frame_annotation;
std::vector<int64> keys_to_be_deleted;
std::vector<int64_t> keys_to_be_deleted;
for (const auto& detected_obj : detected_objects_) {
const int object_id = detected_obj.second.object_id();
if (cancel_object_ids->contains(object_id)) {

View File

@ -78,6 +78,7 @@ cc_library(
hdrs = ["mediapipe_builtin_op_resolver.h"],
deps = [
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
AddCustom("KmeansEmbeddingLookup",
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
// For the UniversalSentenceEncoder model.
AddCustom("TFSentencepieceTokenizeOp",
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
AddCustom("RaggedTensorToTensor",
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
}

View File

@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
filegroup(
name = "config_fbs",
srcs = ["config.fbs"],
@ -80,3 +87,86 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
],
)
cc_library(
name = "sentencepiece_constants",
hdrs = ["sentencepiece_constants.h"],
)
cc_library(
name = "model_converter",
srcs = [
"model_converter.cc",
],
hdrs = [
"model_converter.h",
],
deps = [
":config",
":double_array_trie_builder",
":encoder_config",
":sentencepiece_constants",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
],
)
cc_library(
name = "optimized_encoder",
srcs = [
"optimized_encoder.cc",
],
hdrs = [
"optimized_encoder.h",
],
deps = [
":double_array_trie",
":encoder_config",
":utils",
],
)
cc_library(
name = "sentencepiece_tokenizer_tflite",
srcs = ["sentencepiece_tokenizer_tflite.cc"],
hdrs = ["sentencepiece_tokenizer_tflite.h"],
visibility = [
"//visibility:public",
],
deps =
[
":optimized_encoder",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
],
)
cc_test(
name = "optimized_encoder_test",
srcs = [
"optimized_encoder_test.cc",
],
data = [
":testdata",
],
deps = [
":double_array_trie_builder",
":encoder_config",
":model_converter",
":optimized_encoder",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@org_tensorflow//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,131 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
#include "src/sentencepiece_model.pb.h"
namespace mediapipe::tflite_operations::sentencepiece {
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
DecodePrecompiledCharsmap(
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
// This function "undoes" encoding done by
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
const uint32_t trie_size =
*reinterpret_cast<const uint32_t*>(precompiled_map);
const uint32_t* trie_ptr =
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
precompiled_map + sizeof(uint32_t) + trie_size);
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
sizeof(uint32_t) - trie_size;
return std::make_tuple(
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
}
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset) {
::sentencepiece::ModelProto model_config;
if (!model_config.ParseFromString(model_config_str)) {
return absl::InvalidArgumentError(
"Invalid configuration, can't parse SentencePiece model config " +
model_config.InitializationErrorString());
}
// Convert sentencepieces.
std::vector<std::string> pieces;
pieces.reserve(model_config.pieces_size());
std::vector<float> scores;
scores.reserve(model_config.pieces_size());
std::vector<int> ids;
ids.reserve(model_config.pieces_size());
float min_score = 0.0;
int index = 0;
for (const auto& piece : model_config.pieces()) {
switch (piece.type()) {
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
pieces.push_back(piece.piece());
ids.push_back(index);
if (piece.score() < min_score) {
min_score = piece.score();
}
break;
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
// Ignore unknown and control codes.
break;
default:
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
piece.piece());
}
scores.push_back(piece.score());
++index;
}
flatbuffers::FlatBufferBuilder builder(1024);
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
const auto pieces_score_vector = builder.CreateVector(scores);
TrieBuilder pieces_trie_builder(builder);
pieces_trie_builder.add_nodes(pieces_trie_vector);
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
// Converting normalization.
const auto normalization =
DecodePrecompiledCharsmap(model_config.normalizer_spec());
const auto normalization_trie = std::get<0>(normalization);
const auto normalization_strings = std::get<1>(normalization);
const auto normalization_trie_vector =
builder.CreateVector(normalization_trie);
TrieBuilder normalization_trie_builder(builder);
normalization_trie_builder.add_nodes(normalization_trie_vector);
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
const auto normalization_strings_fbs =
builder.CreateVector(normalization_strings);
EncoderConfigBuilder ecb(builder);
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
ecb.add_start_code(model_config.trainer_spec().bos_id());
ecb.add_end_code(model_config.trainer_spec().eos_id());
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
ecb.add_unknown_penalty(min_score - kUnkPenalty);
ecb.add_encoding_offset(encoding_offset);
ecb.add_pieces(pieces_trie_fbs);
ecb.add_pieces_scores(pieces_score_vector);
ecb.add_remove_extra_whitespaces(
model_config.normalizer_spec().remove_extra_whitespaces());
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
ecb.add_escape_whitespaces(
model_config.normalizer_spec().escape_whitespaces());
ecb.add_normalized_prefixes(normalization_trie_fbs);
ecb.add_normalized_replacements(normalization_strings_fbs);
FinishEncoderConfigBuffer(builder, ecb.Finish());
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
std::string ConvertSentencepieceModel(const std::string& model_string) {
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
assert(result.status().ok());
return result.value();
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,33 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#include <string>
#include "absl/status/statusor.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Converts Sentencepiece configuration to flatbuffer format.
// encoding_offset is used by some encoders that combine different encodings.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset = 0);
std::string ConvertSentencepieceModel(const std::string& model_string);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
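Note: a hedged usage sketch for the converter declared above; the model path, the main() wrapper, and the error handling are illustrative only:
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
int main() {
  // Read a serialized sentencepiece ModelProto from disk (path is hypothetical).
  std::ifstream in("sentencepiece.model", std::ios::binary);
  std::stringstream buffer;
  buffer << in.rdbuf();
  const auto converted = mediapipe::tflite_operations::sentencepiece::
      ConvertSentencepieceModelToFlatBuffer(buffer.str());
  if (!converted.ok()) {
    std::cerr << converted.status() << "\n";
    return 1;
  }
  std::cout << "flatbuffer size: " << converted->size() << " bytes\n";
  return 0;
}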

View File

@ -0,0 +1,236 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <algorithm>
#include <tuple>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace {
const char kSpaceSymbol[] = "\xe2\x96\x81";
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
const std::string& input, const std::vector<int>& offsets,
const processing_callback& pc) {
std::string result_string;
result_string.reserve(input.size());
std::vector<int> result_offsets;
result_offsets.reserve(offsets.size());
for (int i = 0, j = 0; i < input.size();) {
auto result = pc(input.data() + i, input.size() - i);
auto consumed = std::get<0>(result);
auto new_string = std::get<1>(result);
if (consumed == 0) {
// Skip the current byte and move forward.
result_string.push_back(input[i]);
result_offsets.push_back(offsets[j]);
i++;
j++;
continue;
}
result_string.append(new_string.data(), new_string.length());
for (int i = 0; i < new_string.length(); ++i) {
result_offsets.push_back(offsets[j]);
}
j += consumed;
i += consumed;
}
return std::make_tuple(result_string, result_offsets);
}
inline char is_whitespace(char c) {
return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
int len) {
if (len == 0 || !is_whitespace(*data)) {
return std::make_tuple(0, utils::string_view(nullptr, 0));
}
int num_consumed = 1;
for (; num_consumed < len && is_whitespace(data[num_consumed]);
++num_consumed) {
}
return num_consumed > 1
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
: std::make_tuple(0, utils::string_view(nullptr, 0));
}
std::tuple<int, utils::string_view> find_replacement(
const char* data, int len, const DoubleArrayTrie& dat,
const flatbuffers::Vector<int8_t>& replacements) {
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
if (!max_match.empty()) {
// Because flatbuffer byte is signed char which is not the same as char,
// there is the reinterpret_cast here.
const char* replaced_string_ptr =
reinterpret_cast<const char*>(replacements.data() + max_match.id);
return std::make_tuple(max_match.match_length,
utils::string_view(replaced_string_ptr));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
}
} // namespace
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config) {
std::vector<int> output_offsets;
std::string result = in_string;
output_offsets.reserve(in_string.length());
for (int i = 0; i < in_string.length(); ++i) {
output_offsets.push_back(i);
}
if (in_string.empty()) {
return std::make_tuple(result, output_offsets);
}
if (config.add_dummy_prefix()) {
result.insert(result.begin(), ' ');
output_offsets.insert(output_offsets.begin(), 0);
}
// Greedily replace normalized_prefixes with normalized_replacements.
if (config.normalized_prefixes() != nullptr &&
config.normalized_replacements() != nullptr) {
const DoubleArrayTrie normalized_prefixes_matcher(
config.normalized_prefixes()->nodes());
const auto norm_replace = [&config, &normalized_prefixes_matcher](
const char* data, int len) {
return find_replacement(data, len, normalized_prefixes_matcher,
*config.normalized_replacements());
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, norm_replace);
}
if (config.remove_extra_whitespaces()) {
std::tie(result, output_offsets) =
process_string(result, output_offsets, remove_extra_whitespaces);
if (!result.empty() && is_whitespace(result.back())) {
result.pop_back();
output_offsets.pop_back();
}
}
if (config.escape_whitespaces()) {
const auto replace_whitespaces = [](const char* data, int len) {
if (len > 0 && is_whitespace(*data)) {
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, replace_whitespaces);
}
return std::make_tuple(result, output_offsets);
}
EncoderResult EncodeNormalizedString(const std::string& str,
const std::vector<int>& offsets,
const EncoderConfig& config, bool add_bos,
bool add_eos, bool reverse) {
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
const int unknown_code = config.unknown_code();
const float unknown_penalty = config.unknown_penalty();
struct LatticeElement {
float score = 0;
int code = -1;
int prev_position = -1;
LatticeElement(float score_, int code_, int prev_position_)
: score(score_), code(code_), prev_position(prev_position_) {}
LatticeElement() {}
};
const int length = str.length();
std::vector<LatticeElement> lattice(length + 1);
for (int i = 0; i < length; ++i) {
if (i > 0 && lattice[i].prev_position < 0) {
// This state is unreachable.
continue;
}
if (unknown_code >= 0) {
// Put unknown code.
const float penalized_score = lattice[i].score + unknown_penalty;
const int pos = i + 1;
LatticeElement& current_element = lattice[pos];
if (current_element.prev_position < 0 ||
current_element.score < penalized_score) {
current_element = LatticeElement(
penalized_score, unknown_code,
// If the current state is already reached by unknown code, merge
// states.
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
}
}
auto lattice_update = [&lattice, i,
piece_scores](const DoubleArrayTrie::Match& m) {
LatticeElement& target_element = lattice[i + m.match_length];
const float score = lattice[i].score + (*piece_scores)[m.id];
if (target_element.prev_position < 0 || target_element.score < score) {
target_element = LatticeElement(score, m.id, i);
}
};
piece_matcher.IteratePrefixMatches(
utils::string_view(str.data() + i, length - i), lattice_update);
}
EncoderResult result;
if (add_eos) {
result.codes.push_back(config.end_code());
result.offsets.push_back(length);
}
if (lattice[length].prev_position >= 0) {
for (int pos = length; pos > 0;) {
auto code = lattice[pos].code;
if (code != config.unknown_code()) {
code += config.encoding_offset();
}
result.codes.push_back(code);
pos = lattice[pos].prev_position;
result.offsets.push_back(offsets[pos]);
}
}
if (add_bos) {
result.codes.push_back(config.start_code());
result.offsets.push_back(0);
}
if (!reverse) {
std::reverse(result.codes.begin(), result.codes.end());
std::reverse(result.offsets.begin(), result.offsets.end());
}
return result;
}
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse) {
// Get the config from the buffer.
const EncoderConfig* config = GetEncoderConfig(config_buffer);
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
EncoderResult result;
result.type = EncoderResultType::WRONG_CONFIG;
return result;
}
std::string normalized_string;
std::vector<int> offsets;
std::tie(normalized_string, offsets) = NormalizeString(string, *config);
return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
add_eos, reverse);
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
// Sentencepiece encoder optimized with memmapped model.
#include <string>
#include <tuple>
#include <vector>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
namespace mediapipe::tflite_operations::sentencepiece {
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
struct EncoderResult {
EncoderResultType type = EncoderResultType::SUCCESS;
std::vector<int> codes;
std::vector<int> offsets;
};
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config);
// Encodes one string and returns ids and offsets. Takes the configuration as a
// type-erased buffer.
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
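Note: a hedged sketch showing how `EncodeString` consumes the flatbuffer produced by `ConvertSentencepieceModel`, mirroring the test further below; the helper name and input text are illustrative:
#include <iostream>
#include <string>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
namespace sp = mediapipe::tflite_operations::sentencepiece;
// Encodes `text` against a serialized sentencepiece ModelProto and prints the
// resulting ids with their byte offsets into the original string.
void EncodeAndPrint(const std::string& serialized_model_proto,
                    const std::string& text) {
  const std::string config = sp::ConvertSentencepieceModel(serialized_model_proto);
  const sp::EncoderResult result =
      sp::EncodeString(text, config.data(), /*add_bos=*/false,
                       /*add_eos=*/false, /*reverse=*/false);
  if (result.type != sp::EncoderResultType::SUCCESS) {
    std::cerr << "wrong encoder config\n";
    return;
  }
  for (size_t i = 0; i < result.codes.size(); ++i) {
    std::cout << result.codes[i] << " at offset " << result.offsets[i] << "\n";
  }
}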

View File

@ -0,0 +1,171 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <fstream>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "src/sentencepiece.pb.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/core/platform/env.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace internal {
tensorflow::Status TFReadFileToString(const std::string& filepath,
std::string* data) {
return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
data);
}
absl::Status StdReadFileToString(const std::string& filepath,
std::string* data) {
std::ifstream infile(filepath);
if (!infile.is_open()) {
return absl::NotFoundError(
absl::StrFormat("Error when opening %s", filepath));
}
std::string contents((std::istreambuf_iterator<char>(infile)),
(std::istreambuf_iterator<char>()));
data->append(contents);
infile.close();
return absl::OkStatus();
}
} // namespace internal
namespace {
using ::mediapipe::file::JoinPath;
static char kConfigFilePath[] =
"/mediapipe/tasks/cc/text/custom_ops/"
"sentencepiece/testdata/sentencepiece.model";
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
flatbuffers::FlatBufferBuilder builder(1024);
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(true);
ecb.add_add_dummy_prefix(true);
ecb.add_escape_whitespaces(true);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("x y", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
}
{
const auto result = NormalizeString("\tx y\n", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
}
}
TEST(OptimizedEncoder, NormalizeStringReplacement) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
const char norm_replacements[] = "A1\0A2\0A3\0A4";
const auto trie_vector =
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
const auto norm_r = builder.CreateVector<int8_t>(
reinterpret_cast<const signed char*>(norm_replacements),
sizeof(norm_replacements));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto norm_p = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(false);
ecb.add_normalized_prefixes(norm_p);
ecb.add_normalized_replacements(norm_r);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("ABAABAAABAAAA", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "A1BA2BA3BA4");
EXPECT_THAT(offsets,
::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
}
}
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
"X"};
const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
const auto trie_vector =
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
const auto norm_r = builder.CreateVector<int8_t>(
reinterpret_cast<const signed char*>(norm_replacements),
sizeof(norm_replacements));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto norm_p = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(true);
ecb.add_normalized_prefixes(norm_p);
ecb.add_normalized_replacements(norm_r);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("XXABAABAAABAAAA", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, " A1BA2BA3BA4");
EXPECT_THAT(offsets,
::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
}
}
TEST(OptimizedEncoder, ConfigConverter) {
std::string config;
auto status =
internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
ASSERT_TRUE(status.ok());
::sentencepiece::SentencePieceProcessor processor;
ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
const auto converted_model = ConvertSentencepieceModel(config);
const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
const auto encoded =
EncodeString(test_string, converted_model.data(), false, false, false);
ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
::sentencepiece::SentencePieceText reference_encoded;
ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
for (int i = 0; i < encoded.codes.size(); ++i) {
EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
}
}
} // namespace
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
namespace mediapipe::tflite_operations::sentencepiece {
// The constant is copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
constexpr float kUnkPenalty = 10.0;
// These constants are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
//
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_

View File

@ -0,0 +1,129 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe::tflite_operations {
namespace sentencepiece::tokenizer {
namespace {
using ::tflite::SetTensorToDynamic;
constexpr int kSPModelIndex = 0;
constexpr int kInputIndex = 1;
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;
constexpr int kOutputValuesInd = 0;
constexpr int kOutputSplitsInd = 1;
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
int index = 0;
for (const int size : sizes) {
array_size->data[index++] = size;
}
return array_size;
}
} // namespace
// Initializes text encoder object from serialized parameters.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
size_t /*length*/) {
return nullptr;
}
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO: Add checks for input and output tensors.
TfLiteTensor& output_values =
context->tensors[node->outputs->data[kOutputValuesInd]];
SetTensorToDynamic(&output_values);
TfLiteTensor& output_splits =
context->tensors[node->outputs->data[kOutputSplitsInd]];
SetTensorToDynamic(&output_splits);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor& model_tensor =
context->tensors[node->inputs->data[kSPModelIndex]];
const auto model_buffer_data = model_tensor.data.data;
const TfLiteTensor& input_text =
context->tensors[node->inputs->data[kInputIndex]];
const TfLiteTensor add_bos_tensor =
context->tensors[node->inputs->data[kAddBOSInput]];
const bool add_bos = add_bos_tensor.data.b[0];
const TfLiteTensor add_eos_tensor =
context->tensors[node->inputs->data[kAddEOSInput]];
const bool add_eos = add_eos_tensor.data.b[0];
const TfLiteTensor reverse_tensor =
context->tensors[node->inputs->data[kReverseInput]];
const bool reverse = reverse_tensor.data.b[0];
std::vector<int32> encoded;
std::vector<int32> splits;
const int num_strings = tflite::GetStringCount(&input_text);
for (int i = 0; i < num_strings; ++i) {
const auto strref = tflite::GetString(&input_text, i);
const auto res = EncodeString(std::string(strref.str, strref.len),
model_buffer_data, add_bos, add_eos, reverse);
TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
"Sentencepiece conversion failed");
std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
splits.emplace_back(encoded.size());
}
TfLiteTensor& output_values =
context->tensors[node->outputs->data[kOutputValuesInd]];
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(
context, &output_values,
CreateSizeArray({static_cast<int>(encoded.size())})));
int32_t* output_values_flat = output_values.data.i32;
std::copy(encoded.begin(), encoded.end(), output_values_flat);
TfLiteTensor& output_splits =
context->tensors[node->outputs->data[kOutputSplitsInd]];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(
context, &output_splits,
CreateSizeArray({static_cast<int>(splits.size() + 1)})));
int32_t* output_splits_flat = output_splits.data.i32;
*output_splits_flat = 0;
std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
return kTfLiteOk;
}
} // namespace sentencepiece::tokenizer
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
static TfLiteRegistration r = {
sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
return &r;
}
} // namespace mediapipe::tflite_operations

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe::tflite_operations {
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
} // namespace mediapipe::tflite_operations
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
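Note: a hedged sketch of wiring the new op into a TFLite resolver, mirroring the MediaPipeBuiltinOpResolver change earlier in this diff; the helper function name is illustrative:
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "tensorflow/lite/kernels/register.h"
// Returns a resolver that can load models containing the custom
// TFSentencepieceTokenizeOp node; the op name must match the model graph.
tflite::ops::builtin::BuiltinOpResolver MakeResolverWithSentencepiece() {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom(
      "TFSentencepieceTokenizeOp",
      mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
  return resolver;
}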

View File

@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
// Embedding model with regex preprocessing.
constexpr char kRegexOneEmbeddingModel[] =
"regex_one_embedding_with_metadata.tflite";
constexpr char kUniversalSentenceEncoderModel[] =
"universal_sentence_encoder_qa_with_metadata.tflite";
// Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4;
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result0,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result0.embeddings.size(), 1);
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("When you go to this restaurant, they hold the "
"pancake upside-down before they hand it "
"to you. It's a great gimmick."));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result1,
text_embedder->Embed(
"Let's make a plan to steal the declaration of independence."));
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -23,18 +23,12 @@ cc_library(
srcs = ["face_stylizer_graph.cc"],
deps = [
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
"//mediapipe/calculators/image:warp_affine_calculator",
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:face_to_rect_calculator",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:inverse_matrix_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
@ -53,7 +47,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",

View File

@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The input image can be of any size with format RGB or RGBA.
// When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the stylized
// output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
// face. The stylized output image size is the same as the model output size.
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing.
// When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the stylized
// output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
// face. The stylized output image size is the same as the model output size.
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The "result_callback" provides:
// - When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the
// stylized output image size is the smaller of the model output size and
// the size of the 'region_of_interest' specified in
// 'image_processing_options'.
// face. The stylized output image size is the same as the model output
// size.
// - The input timestamp in milliseconds.
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions>
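Taken together, the revised comments above define a simple contract: the stylized image always matches the model output size, and an empty optional means no face was found. Below is a hedged usage sketch; only Stylize() and Close() appear in this header, while FaceStylizerOptions, Create(), and the model path are assumed to follow the usual C++ task pattern:

#include <memory>
#include <optional>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h"

absl::Status StylizeOnce(mediapipe::Image input) {
  using ::mediapipe::tasks::vision::face_stylizer::FaceStylizer;
  using ::mediapipe::tasks::vision::face_stylizer::FaceStylizerOptions;

  auto options = std::make_unique<FaceStylizerOptions>();
  options->base_options.model_asset_path = "face_stylizer.task";  // placeholder

  auto stylizer_or = FaceStylizer::Create(std::move(options));
  if (!stylizer_or.ok()) return stylizer_or.status();
  std::unique_ptr<FaceStylizer> stylizer = std::move(*stylizer_or);

  auto result = stylizer->Stylize(input);
  if (!result.ok()) return result.status();
  if (!result->has_value()) {
    // No face detected: an empty optional, not an error status.
  } else {
    const mediapipe::Image& stylized = result->value();
    // stylized.width() and stylized.height() equal the model output size.
    (void)stylized;
  }
  return stylizer->Close();
}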

View File

@ -19,8 +19,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
image_in >> preprocessing.In(kImageTag);
face_rect >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto transform_matrix = preprocessing.Out(kMatrixTag);
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.
@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
model_output_tensors >> tensors_to_image.In(kTensorsTag);
auto tensor_image = tensors_to_image.Out(kImageTag);
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
transform_matrix >> inverse_matrix.In(kMatrixTag);
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
auto& image_converter = graph.AddNode("ImageCloneCalculator");
image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
.set_output_on_gpu(false);
tensor_image >> image_converter.In("");
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
auto& warp_affine_options =
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
warp_affine_options.set_border_mode(
WarpAffineCalculatorOptions::BORDER_ZERO);
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
tensor_image >> warp_affine.In(kImageTag);
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
image_size >> warp_affine.In(kOutputSizeTag);
auto image_to_crop = warp_affine.Out(kImageTag);
// The following calculators are for cropping and resizing the output image
// based on the roi and the model output size. As the WarpAffineCalculator
// rotates the image based on the transform matrix, the rotation info in the
// rect proto is stripped to prevent the ImageCroppingCalculator from
// performing extra rotation.
auto& strip_rotation =
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
face_rect >> strip_rotation.In(kNormRectTag);
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
auto& from_image = graph.AddNode("FromImageCalculator");
image_to_crop >> from_image.In(kImageTag);
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
auto& image_cropping_opts =
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
image_cropping_opts.set_output_max_width(
image_to_tensor_options.output_tensor_width());
image_cropping_opts.set_output_max_height(
image_to_tensor_options.output_tensor_height());
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
auto& to_image = graph.AddNode("ToImageCalculator");
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
// graph selects its cpu or gpu path based on the image preprocessing
// backend.
if (use_gpu) {
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
} else {
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
}
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
}
};
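Pieced together from the hunks above, the output path after inference is now just TensorsToImageCalculator followed by ImageCloneCalculator. The following is a small self-contained sketch of that wiring with the api2 builder; the graph input stands in for the tensor image produced upstream, so this illustrates the pattern rather than reproducing the real FaceStylizerGraph:

#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"

mediapipe::CalculatorGraphConfig BuildStylizedOutputTail() {
  using ::mediapipe::Image;
  using ::mediapipe::api2::builder::Graph;

  Graph graph;
  // Stand-in for the tensor image emitted by TensorsToImageCalculator.
  auto tensor_image = graph.In("IMAGE").Cast<Image>();

  // Clone the image into CPU memory; there is no warping or cropping back
  // onto the input frame, so the output keeps the model output size.
  auto& image_converter = graph.AddNode("ImageCloneCalculator");
  image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
      .set_output_on_gpu(false);
  tensor_image >> image_converter.In("");

  image_converter.Out("") >> graph.Out("STYLIZED_IMAGE");
  return graph.GetConfig();
}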

View File

@ -100,6 +100,7 @@ cc_library(
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(

View File

@ -90,7 +90,7 @@ NS_SWIFT_NAME(ClassificationResult)
* amount of data to process might exceed the maximum size that the model can process: to solve
* this, the input data is split into multiple chunks starting at different timestamps.
*/
@property(nonatomic, readonly) NSInteger timestampMs;
@property(nonatomic, readonly) NSInteger timestampInMilliseconds;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications and time
@ -98,14 +98,15 @@ NS_SWIFT_NAME(ClassificationResult)
*
* @param classifications An Array of `MPPClassifications` objects containing the predicted
* categories for each head of the model.
* @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data
* @param timestampInMilliseconds The timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*
* @return An instance of `MPPClassificationResult` initialized with the given array of
* classifications and timestampMs.
* classifications and timestamp (in milliseconds).
*/
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
NS_DESIGNATED_INITIALIZER;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -38,11 +38,11 @@
@implementation MPPClassificationResult
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super init];
if (self) {
_classifications = classifications;
_timestampMs = timestampMs;
_timestampInMilliseconds = timestampInMilliseconds;
}
return self;

View File

@ -33,7 +33,7 @@ NS_SWIFT_NAME(EmbeddingResult)
* cases, the amount of data to process might exceed the maximum size that the model can process. To
* solve this, the input data is split into multiple chunks starting at different timestamps.
*/
@property(nonatomic, readonly) NSInteger timestampMs;
@property(nonatomic, readonly) NSInteger timestampInMilliseconds;
/**
* Initializes a new `MPPEmbeddingResult` with the given array of embeddings and timestamp (in
@ -41,14 +41,14 @@ NS_SWIFT_NAME(EmbeddingResult)
*
* @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each
* head of the model.
* @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results. Pass `0` if timestamp is absent.
* @param timestampInMilliseconds The optional timestamp (in milliseconds) of the start of the chunk
* of data corresponding to these results. Pass `0` if timestamp is absent.
*
* @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and
* timestampMs.
* timestamp (in milliseconds).
*/
- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -17,11 +17,11 @@
@implementation MPPEmbeddingResult
- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
timestampMs:(NSInteger)timestampMs {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super init];
if (self) {
_embeddings = embeddings;
_timestampMs = timestampMs;
_timestampInMilliseconds = timestampInMilliseconds;
}
return self;

View File

@ -55,13 +55,13 @@ using ClassificationResultProto =
[classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]];
}
NSInteger timestampMs = 0;
NSInteger timestampInMilliseconds = 0;
if (classificationResultProto.has_timestamp_ms()) {
timestampMs = (NSInteger)classificationResultProto.timestamp_ms();
timestampInMilliseconds = (NSInteger)classificationResultProto.timestamp_ms();
}
return [[MPPClassificationResult alloc] initWithClassifications:classifications
timestampMs:timestampMs];
timestampInMilliseconds:timestampInMilliseconds];
;
}

View File

@ -31,12 +31,13 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::
[embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]];
}
NSInteger timestampMs = 0;
NSInteger timestampInMilliseconds = 0;
if (embeddingResultProto.has_timestamp_ms()) {
timestampMs = (NSInteger)embeddingResultProto.timestamp_ms();
timestampInMilliseconds = (NSInteger)embeddingResultProto.timestamp_ms();
}
return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs];
return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings
timestampInMilliseconds:timestampInMilliseconds];
}
@end

View File

@ -26,11 +26,12 @@ NS_SWIFT_NAME(TaskResult)
/**
* Timestamp that is associated with the task result object.
*/
@property(nonatomic, assign, readonly) NSInteger timestampMs;
@property(nonatomic, assign, readonly) NSInteger timestampInMilliseconds;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds
NS_DESIGNATED_INITIALIZER;
@end

View File

@ -16,16 +16,16 @@
@implementation MPPTaskResult
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs {
- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super init];
if (self) {
_timestampMs = timestampMs;
_timestampInMilliseconds = timestampInMilliseconds;
}
return self;
}
- (id)copyWithZone:(NSZone *)zone {
return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs];
return [[MPPTaskResult alloc] initWithTimestampInMilliseconds:self.timestampInMilliseconds];
}
@end

View File

@ -487,7 +487,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSError *liveStreamApiCallError;
XCTAssertFalse([imageClassifier classifyAsyncImage:image
timestampMs:0
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
@ -501,7 +501,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]);
XCTAssertFalse([imageClassifier classifyVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
@ -524,7 +526,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSError *liveStreamApiCallError;
XCTAssertFalse([imageClassifier classifyAsyncImage:image
timestampMs:0
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
@ -575,7 +577,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]);
XCTAssertFalse([imageClassifier classifyVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
@ -601,7 +605,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
for (int i = 0; i < 3; i++) {
MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image
timestampMs:i
timestampInMilliseconds:i
error:nil];
[self assertImageClassifierResult:imageClassifierResult
hasExpectedCategoriesCount:maxResults
@ -630,10 +634,10 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:1 error:nil]);
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]);
NSError *error;
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampMs:0 error:&error]);
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampInMilliseconds:0 error:&error]);
NSError *expectedError =
[NSError errorWithDomain:kExpectedErrorDomain
@ -668,7 +672,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
for (int i = 0; i < 3; i++) {
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]);
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]);
}
}

View File

@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextClassifierResult)
*
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
* per classifier head.
* @param timestampMs The timestamp for this result.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `MPPTextClassifierResult` initialized with the given
* `MPPClassificationResult` and timestamp (in milliseconds).
*/
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end

View File

@ -17,8 +17,8 @@
@implementation MPPTextClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs {
self = [super initWithTimestampMs:timestampMs];
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_classificationResult = classificationResult;
}

View File

@ -35,7 +35,7 @@ using ::mediapipe::Packet;
return [[MPPTextClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() /
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}

View File

@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextEmbedderResult)
*
* @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results
* per embedder head.
* @param timestampMs The timestamp for this result.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `MPPTextEmbedderResult` initialized with the given
* `MPPEmbeddingResult` and timestamp (in milliseconds).
*/
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
timestampMs:(NSInteger)timestampMs;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -17,8 +17,8 @@
@implementation MPPTextEmbedderResult
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
timestampMs:(NSInteger)timestampMs {
self = [super initWithTimestampMs:timestampMs];
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_embeddingResult = embeddingResult;
}

View File

@ -34,7 +34,7 @@ using ::mediapipe::Packet;
return [[MPPTextEmbedderResult alloc]
initWithEmbeddingResult:embeddingResult
timestampMs:(NSInteger)(packet.Timestamp().Value() /
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}

View File

@ -41,7 +41,7 @@
* timestamp.
*
* @param image The image to send to the MediaPipe graph.
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
* @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
@ -49,7 +49,7 @@
* occurred during the conversion.
*/
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error;
/**
@ -66,11 +66,11 @@
* specified timestamp.
*
* @param normalizedRect The `NormalizedRect` to send to the MediaPipe graph.
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
* @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
*
* @return The MediaPipe packet containing the normalized rect.
*/
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect
timestampMs:(NSInteger)timestampMs;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end

View File

@ -42,7 +42,7 @@ using ::mediapipe::Timestamp;
}
+ (Packet)createPacketWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error];
@ -51,7 +51,7 @@ using ::mediapipe::Timestamp;
}
return MakePacket<Image>(std::move(imageFrame))
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
.At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
}
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect {
@ -59,9 +59,9 @@ using ::mediapipe::Timestamp;
}
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect
timestampMs:(NSInteger)timestampMs {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
return MakePacket<NormalizedRect>(std::move(normalizedRect))
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
.At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
}
@end
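The wrapper above boils down to one pattern: scale the millisecond timestamp into MediaPipe's microsecond Timestamp when stamping the packet. Here is a minimal C++ sketch of that conversion, using an int payload purely for illustration (the constant mirrors kMicroSecondsPerMilliSecond in the wrapper):

#include <cstdint>

#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h"

constexpr int64_t kMicroSecondsPerMilliSecond = 1000;

mediapipe::Packet MakeStampedPacket(int payload, int64_t timestamp_ms) {
  // Packet timestamps are expressed in microseconds, so callers passing
  // milliseconds (as the Objective-C API does) multiply by 1000 first.
  return mediapipe::MakePacket<int>(payload).At(
      mediapipe::Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond));
}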

View File

@ -21,7 +21,7 @@
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
gestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampMs:timestampInMilliseconds];
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_landmarks = landmarks;
_worldLandmarks = worldLandmarks;

View File

@ -122,17 +122,17 @@ NS_SWIFT_NAME(ImageClassifier)
* `MPPRunningModeVideo`.
*
* @param image The `MPPImage` on which image classification is to be performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input video frame.
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(classify(videoFrame:timestampMs:));
NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:));
/**
* Performs image classification on the provided video frame of type `MPPImage` cropped to the
@ -145,8 +145,8 @@ NS_SWIFT_NAME(ImageClassifier)
*
* @param image Live stream image data of type `MPPImage` on which image classification is to be
* performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the video frame of type
* `MPPImage`, on which image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image
@ -155,10 +155,10 @@ NS_SWIFT_NAME(ImageClassifier)
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:));
NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:regionOfInterest:));
/**
* Sends live stream image data of type `MPPImage` to perform image classification using the whole
@ -172,16 +172,17 @@ NS_SWIFT_NAME(ImageClassifier)
*
* @param image Live stream image data of type `MPPImage` on which image classification is to be
* performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image is sent to the image classifier. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input live stream image data.
*
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:));
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:));
/**
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
@ -195,8 +196,8 @@ NS_SWIFT_NAME(ImageClassifier)
*
* @param image Live stream image data of type `MPPImage` on which image classification is to be
* performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image is sent to the image classifier. The input timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
* of type `MPPImage`, on which image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image
@ -205,10 +206,10 @@ NS_SWIFT_NAME(ImageClassifier)
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:));
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:regionOfInterest:));
- (instancetype)init NS_UNAVAILABLE;
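The same timestamp contract applies outside Objective-C. As a point of comparison, here is a hedged sketch against the C++ ImageClassifier; the Create/ClassifyForVideo entry points, the options fields, and the model path are assumed to follow the usual C++ task pattern and are not part of this change:

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"

absl::Status ClassifyFrames(const std::vector<mediapipe::Image>& frames) {
  using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;
  using ::mediapipe::tasks::vision::image_classifier::ImageClassifierOptions;

  auto options = std::make_unique<ImageClassifierOptions>();
  options->base_options.model_asset_path = "classifier.tflite";  // placeholder
  options->running_mode = mediapipe::tasks::vision::core::RunningMode::VIDEO;

  auto classifier_or = ImageClassifier::Create(std::move(options));
  if (!classifier_or.ok()) return classifier_or.status();
  std::unique_ptr<ImageClassifier> classifier = std::move(*classifier_or);

  // Video-mode timestamps must be monotonically increasing, just like the
  // timestampInMilliseconds arguments documented above.
  for (size_t i = 0; i < frames.size(); ++i) {
    auto result = classifier->ClassifyForVideo(
        frames[i], /*timestamp_ms=*/static_cast<int64_t>(i));
    if (!result.ok()) return result.status();
  }
  return classifier->Close();
}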

View File

@ -149,7 +149,7 @@ static NSString *const kTaskGraphName =
}
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
@ -162,14 +162,15 @@ static NSString *const kTaskGraphName =
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
error:error];
if (imagePacket.IsEmpty()) {
return std::nullopt;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampInMilliseconds:timestampInMilliseconds];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap;
@ -180,11 +181,11 @@ static NSString *const kTaskGraphName =
}
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
@ -204,20 +205,20 @@ static NSString *const kTaskGraphName =
}
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
return [self classifyVideoFrame:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero
error:error];
}
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
@ -228,10 +229,10 @@ static NSString *const kTaskGraphName =
}
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
return [self classifyAsyncImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero
error:error];
}

View File

@ -31,13 +31,13 @@ NS_SWIFT_NAME(ImageClassifierResult)
*
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
* per classifier head.
* @param timestampMs The timestamp for this result.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `MPPImageClassifierResult` initialized with the given
* `MPPClassificationResult` and timestamp (in milliseconds).
*/
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end

View File

@ -17,8 +17,8 @@
@implementation MPPImageClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs {
self = [super initWithTimestampMs:timestampMs];
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_classificationResult = classificationResult;
}

View File

@ -34,7 +34,7 @@ using ::mediapipe::Packet;
return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() /
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}

View File

@ -36,13 +36,13 @@ NS_SWIFT_NAME(ObjectDetectionResult)
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
* expressed in the unrotated input frame of reference coordinate system, i.e. in `[0,image_width)
* x [0,image_height)`, which are the dimensions of the underlying image data.
* @param timestampMs The timestamp for this result.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
* and timestamp (in milliseconds).
*/
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampMs:(NSInteger)timestampMs;
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end

View File

@ -17,8 +17,8 @@
@implementation MPPObjectDetectionResult
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampMs:(NSInteger)timestampMs {
self = [super initWithTimestampMs:timestampMs];
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_detections = detections;
}

View File

@ -138,8 +138,8 @@ NS_SWIFT_NAME(ObjectDetector)
* `MPPRunningModeVideo`.
*
* @param image The `MPPImage` on which object detection is to be performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing object
* detection on the input image.
*
@ -149,9 +149,9 @@ NS_SWIFT_NAME(ObjectDetector)
* image data.
*/
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampMs:));
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
/**
* Performs object detection on the provided video frame of type `MPPImage` cropped to the
@ -164,8 +164,8 @@ NS_SWIFT_NAME(ObjectDetector)
*
* @param image Live stream image data of type `MPPImage` on which object detection is to be
* performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
* object detection should be performed.
*
@ -178,10 +178,10 @@ NS_SWIFT_NAME(ObjectDetector)
* image data.
*/
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampMs:regionOfInterest:));
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:regionOfInterest:));
/**
* Sends live stream image data of type `MPPImage` to perform object detection using the whole
@ -195,16 +195,17 @@ NS_SWIFT_NAME(ObjectDetector)
*
* @param image Live stream image data of type `MPPImage` on which object detection is to be
* performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
* to the object detector. The input timestamps must be monotonically increasing.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image is sent to the object detector. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing object
* detection on the input live stream image data.
*
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error NS_SWIFT_NAME(detectAsync(image:timestampMs:));
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
/**
* Sends live stream image data of type `MPPImage` to perform object detection, cropped to the
@ -218,8 +219,8 @@ NS_SWIFT_NAME(ObjectDetector)
*
* @param image Live stream image data of type `MPPImage` on which object detection is to be
* performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
* to the object detector. The input timestamps must be monotonically increasing.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image is sent to the object detector. The input timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
* of type `MPPImage`, on which object detection should be performed.
* @param error An optional error parameter populated when there is an error in performing object
@ -228,10 +229,10 @@ NS_SWIFT_NAME(ObjectDetector)
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampMs:regionOfInterest:));
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:regionOfInterest:));
- (instancetype)init NS_UNAVAILABLE;

View File

@ -157,7 +157,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
}
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
@ -170,14 +170,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
error:error];
if (imagePacket.IsEmpty()) {
return std::nullopt;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampInMilliseconds:timestampInMilliseconds];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap;
@ -188,11 +189,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
}
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
@ -212,20 +213,20 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
}
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
return [self detectInVideoFrame:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero
error:error];
}
- (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
@ -236,10 +237,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
}
- (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error {
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
return [self detectAsyncInImage:image
timestampMs:timestampMs
timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero
error:error];
}

View File

@ -38,8 +38,9 @@ using ::mediapipe::Packet;
}
return [[MPPObjectDetectionResult alloc]
initWithDetections:detections
timestampMs:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)];
initWithDetections:detections
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}
@end

View File

@ -198,9 +198,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
* size of the stylized output is based the model output size and can be smaller than the input
* image.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created
@ -220,9 +220,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
@ -256,9 +256,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
* size of the stylized output is based the model output size and can be smaller than the input
* image.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
@ -281,9 +281,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
@ -320,9 +320,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
* size of the stylized output is based the model output size and can be smaller than the input
* image.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
@ -346,9 +346,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
@ -387,9 +387,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
* size of the stylized output is based the model output size and can be smaller than the input
* image.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
@ -414,9 +414,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
@ -445,9 +445,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
*
* <p>{@link FaceStylizer} supports the following color space types:
*
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
* size of the stylized output is based the model output * size and can be smaller than the input
* image.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
@ -475,9 +475,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
*
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* <p>The input image can be of any size. The output image is the stylized image with the most
* visible face. The stylized output image size is the same as the model output size. When no face
* is detected on the input image, returns {@code Optional.empty()}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the

View File

@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"IMAGE:" + IMAGE_IN_STREAM_NAME,
"ROI:" + ROI_IN_STREAM_NAME,
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"GROUPED_SEGMENTATION:segmented_mask_out",
"IMAGE:image_out",
"SEGMENTATION:0:segmentation"));
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final int IMAGE_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
*/
public static InteractiveSegmenter createFromOptions(
Context context, InteractiveSegmenterOptions segmenterOptions) {
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
throw new IllegalArgumentException(
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
}
List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
if (segmenterOptions.outputConfidenceMasks()) {
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
}
final int confidenceMasksOutStreamIndex = outputStreams.size() - 1;
if (segmenterOptions.outputCategoryMask()) {
outputStreams.add("CATEGORY_MASK:category_mask");
}
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
}
List<MPImage> segmentedMasks = new ArrayList<>();
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
int imageFormat =
segmenterOptions.outputType()
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK
? MPImage.IMAGE_FORMAT_VEC32F1
: MPImage.IMAGE_FORMAT_ALPHA;
int imageListSize =
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
if (!segmenterOptions.resultListener().isPresent()) {
for (int i = 0; i < imageListSize; i++) {
buffersArray[i] =
ByteBuffer.allocateDirect(
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
// If resultListener is not provided, the resulting MPImage is deep-copied from
// the mediapipe graph. If provided, the resulting MPImage wraps the mediapipe
// packet memory.
boolean copyImage = !segmenterOptions.resultListener().isPresent();
Optional<List<MPImage>> confidenceMasks = Optional.empty();
if (segmenterOptions.outputConfidenceMasks()) {
confidenceMasks = Optional.of(new ArrayList<>());
int width =
PacketGetter.getImageWidthFromImageList(
packets.get(confidenceMasksOutStreamIndex));
int height =
PacketGetter.getImageHeightFromImageList(
packets.get(confidenceMasksOutStreamIndex));
int imageListSize =
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
// Confidence masks are float-type images.
final int numBytes = 4;
if (copyImage) {
for (int i = 0; i < imageListSize; i++) {
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
}
}
if (!PacketGetter.getImageList(
packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting confidence masks.");
}
for (ByteBuffer buffer : buffersArray) {
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
confidenceMasks.get().add(builder.build());
}
}
if (!PacketGetter.getImageList(
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
buffersArray,
!segmenterOptions.resultListener().isPresent())) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting segmented masks. It usually results from incorrect"
+ " options of unsupported OutputType of given model.");
}
for (ByteBuffer buffer : buffersArray) {
Optional<MPImage> categoryMask = Optional.empty();
if (segmenterOptions.outputCategoryMask()) {
int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
ByteBuffer buffer;
if (copyImage) {
buffer = ByteBuffer.allocateDirect(width * height);
if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting category mask.");
}
} else {
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
}
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
segmentedMasks.add(builder.build());
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
categoryMask = Optional.of(builder.build());
}
return ImageSegmenterResult.create(
Optional.of(segmentedMasks),
Optional.empty(),
confidenceMasks,
categoryMask,
BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
}
@Override
@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
.setTaskRunningModeName(RunningMode.IMAGE.name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setOutputStreams(outputStreams)
.setTaskOptions(segmenterOptions)
.setEnableFlowLimiting(false)
.build(),
@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
/** Sets the base options for the image segmenter task. */
public abstract Builder setBaseOptions(BaseOptions value);
/** The output type from image segmenter. */
public abstract Builder setOutputType(OutputType value);
/** Sets whether to output confidence masks. Defaults to true. */
public abstract Builder setOutputConfidenceMasks(boolean value);
/** Sets whether to output category mask. Defaults to false. */
public abstract Builder setOutputCategoryMask(boolean value);
/**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
abstract BaseOptions baseOptions();
abstract OutputType outputType();
abstract boolean outputConfidenceMasks();
abstract boolean outputCategoryMask();
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener();
/** The output type of segmentation results. */
public enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK
}
public static Builder builder() {
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
.setOutputType(OutputType.CATEGORY_MASK);
.setOutputConfidenceMasks(true)
.setOutputCategoryMask(false);
}
/**
@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
SegmenterOptionsProto.SegmenterOptions.newBuilder();
if (outputType() == OutputType.CONFIDENCE_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
} else if (outputType() == OutputType.CATEGORY_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(

View File

@ -234,8 +234,8 @@ public class FaceStylizerTest {
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
MPImage stylizedImage = actualResult.stylizedImage().get();
assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()).isEqualTo(83);
assertThat(stylizedImage.getHeight()).isEqualTo(83);
assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
}
@Test

View File

@ -53,18 +53,15 @@ public class InteractiveSegmenterTest {
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK)
.setOutputConfidenceMasks(false)
.setOutputCategoryMask(true)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
// TODO update to correct category mask output.
// After InteractiveSegmenter updated according to (b/276519300), update this to use
// categoryMask field instead of confidenceMasks.
List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(1);
assertThat(actualResult.categoryMask().isPresent()).isTrue();
}
@Test
@ -75,15 +72,17 @@ public class InteractiveSegmenterTest {
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(2);
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
assertThat(confidenceMasks.size()).isEqualTo(2);
}
}

View File

@ -204,6 +204,11 @@ This can be useful for resetting a stateful task graph to process new data.
Raises:
RuntimeError: The underlying mediapipe graph fails to reset and restart.
)doc");
task_runner.def(
"get_graph_config",
[](TaskRunner* self) { return self->GetGraphConfig(); },
R"doc(Returns the canonicalized CalculatorGraphConfig of the underlying graph.)doc");
}
} // namespace python

View File

@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Tolerance for embedding vector coordinate values.
_EPSILON = 1e-4
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
16,
(0.549632, 0.552879),
),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
)
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
expected_similarity, expected_size, expected_first_values):
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
16,
(0.549632, 0.552879),
),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
)
def test_embed_in_context(self, l2_normalize, quantize, model_name,
model_file_type, expected_similarity, expected_size,
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
@parameterized.parameters(
# TODO: The similarity should likely be lower
(_BERT_MODEL_FILE, 0.980880),
(_USE_MODEL_FILE, 0.780334),
)
def test_embed_with_different_themes(self, model_file, expected_similarity):
# Creates embedder.
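
For reference, a minimal sketch of the Python API these new universal-sentence-encoder cases exercise. The model filename comes from the constants above; the sentences are made up for illustration:

from mediapipe.tasks import python as mp_python
from mediapipe.tasks.python import text

# Illustrative model path; any TextEmbedder-compatible .tflite with metadata works.
options = text.TextEmbedderOptions(
    base_options=mp_python.BaseOptions(
        model_asset_path='universal_sentence_encoder_qa_with_metadata.tflite'))

with text.TextEmbedder.create_from_options(options) as embedder:
    result0 = embedder.embed('MediaPipe is a cross-platform ML framework')
    result1 = embedder.embed('MediaPipe runs ML pipelines on many platforms')
    # Cosine similarity between the first embedding of each result.
    print(embedder.cosine_similarity(result0.embeddings[0], result1.embeddings[0]))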

View File

@ -15,7 +15,6 @@
import enum
import os
from typing import List
from unittest import mock
from absl.testing import absltest
@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
_BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image
_ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
_Activation = image_segmenter.ImageSegmenterOptions.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
@ -42,11 +40,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_CAT_IMAGE = 'cat.jpg'
_CAT_MASK = 'cat_mask.jpg'
_MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _calculate_soft_iou(m1, m2):
intersection_sum = np.sum(m1 * m2)
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
if union_sum > 0:
return intersection_sum / union_sum
else:
return 0
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
actual_mask = actual_mask.numpy_view()
expected_mask = expected_mask.numpy_view() / 255.0
return (
actual_mask.shape == expected_mask.shape
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
)
def _similar_to_uint8_mask(actual_mask, expected_mask):
actual_mask_pixels = actual_mask.numpy_view().flatten()
expected_mask_pixels = expected_mask.numpy_view().flatten()
@ -56,8 +76,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask):
for index in range(num_pixels):
consistent_pixels += (
actual_mask_pixels[index] *
_MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index])
actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR
== expected_mask_pixels[index]
)
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
@ -73,16 +94,27 @@ class ImageSegmenterTest(parameterized.TestCase):
super().setUp()
# Load the test input image.
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
)
# Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)),
cv2.IMREAD_GRAYSCALE)
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)
),
cv2.IMREAD_GRAYSCALE,
)
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
)
def _load_segmentation_mask(self, file_path: str):
# Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
cv2.IMREAD_GRAYSCALE,
)
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
@ -98,9 +130,11 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
model_asset_path='/path/to/invalid/model.tflite'
)
options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options)
@ -112,8 +146,9 @@ class ImageSegmenterTest(parameterized.TestCase):
segmenter = _ImageSegmenter.create_from_options(options)
self.assertIsInstance(segmenter, _ImageSegmenter)
@parameterized.parameters((ModelFileType.FILE_NAME,),
(ModelFileType.FILE_CONTENT,))
@parameterized.parameters(
(ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,)
)
def test_segment_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME:
@ -127,22 +162,27 @@ class ImageSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
base_options=base_options,
output_category_mask=True,
output_confidence_masks=False,
)
segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1)
category_mask = category_masks[0]
segmentation_result = segmenter.segment(self.test_image)
category_mask = segmentation_result.category_mask
result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct.
self.assertEqual(result_pixels.dtype, np.uint8)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
_similar_to_uint8_mask(category_mask, self.test_seg_image),
(
'Number of pixels in the candidate mask differing from that of the'
f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
@ -152,74 +192,46 @@ class ImageSegmenterTest(parameterized.TestCase):
# Creates segmenter.
base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode.
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
segmenter = _ImageSegmenter.create_from_options(options)
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _ImageSegmenterOptions(
base_options=base_options,
output_type=_OutputType.CONFIDENCE_MASK,
activation=_Activation.SOFTMAX)
segmenter = _ImageSegmenter.create_from_options(options)
confidence_masks = segmenter.segment(self.test_image)
output_category_mask=False,
output_confidence_masks=True,
)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if data type of `confidence_masks` are correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
@parameterized.parameters((ModelFileType.FILE_NAME),
(ModelFileType.FILE_CONTENT))
def test_segment_in_context(self, model_file_type):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_contents = f.read()
base_options = _BaseOptions(model_asset_buffer=model_contents)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1)
segmentation_result = segmenter.segment(test_image)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
def test_missing_result_callback(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass
@ -228,130 +240,236 @@ class ImageSegmenterTest(parameterized.TestCase):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass
def test_calling_segment_for_video_in_image_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
segmenter.segment_for_video(self.test_image, 0)
def test_calling_segment_async_in_image_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
segmenter.segment_async(self.test_image, 0)
def test_calling_segment_in_video_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
segmenter.segment(self.test_image)
def test_calling_segment_async_in_video_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
segmenter.segment_async(self.test_image, 0)
def test_segment_for_video_with_out_of_order_timestamp(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
unused_result = segmenter.segment_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
ValueError, r'Input timestamp must be monotonically increasing'
):
segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self):
def test_segment_for_video_in_category_mask_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK,
running_mode=_RUNNING_MODE.VIDEO)
output_category_mask=True,
output_confidence_masks=False,
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
category_masks = segmenter.segment_for_video(self.test_image, timestamp)
self.assertLen(category_masks, 1)
segmentation_result = segmenter.segment_for_video(
self.test_image, timestamp
)
category_mask = segmentation_result.category_mask
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
_similar_to_uint8_mask(category_mask, self.test_seg_image),
(
'Number of pixels in the candidate mask differing from that of'
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
def test_segment_for_video_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
output_category_mask=False,
output_confidence_masks=True,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmentation_result = segmenter.segment_for_video(test_image, timestamp)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
def test_calling_segment_in_live_stream_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
segmenter.segment(self.test_image)
def test_calling_segment_for_video_in_live_stream_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
segmenter.segment_for_video(self.test_image, 0)
def test_segment_async_calls_with_illegal_timestamp(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter:
segmenter.segment_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
ValueError, r'Input timestamp must be monotonically increasing'
):
segmenter.segment_async(self.test_image, 0)
def test_segment_async_calls(self):
def test_segment_async_calls_in_category_mask_mode(self):
observed_timestamp_ms = -1
def check_result(result: List[image_module.Image], output_image: _Image,
timestamp_ms: int):
def check_result(
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
):
# Get the output category mask.
category_mask = result[0]
category_mask = result.category_mask
self.assertEqual(output_image.width, self.test_image.width)
self.assertEqual(output_image.height, self.test_image.height)
self.assertEqual(output_image.width, self.test_seg_image.width)
self.assertEqual(output_image.height, self.test_seg_image.height)
self.assertTrue(
_similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
(
'Number of pixels in the candidate mask differing from that of'
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK,
output_category_mask=True,
output_confidence_masks=False,
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
result_callback=check_result,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp)
def test_segment_async_calls_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
observed_timestamp_ms = -1
def check_result(
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
):
# Get the output category mask.
confidence_masks = result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
output_category_mask=False,
output_confidence_masks=True,
result_callback=check_result,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(test_image, timestamp)
if __name__ == '__main__':
absltest.main()
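
As a quick numeric illustration of the soft-IoU computation in the `_calculate_soft_iou` helper defined near the top of this test file (the values below are made up):

import numpy as np

m1 = np.array([0.9, 0.1, 0.8, 0.0])  # predicted confidences
m2 = np.array([1.0, 0.0, 1.0, 0.0])  # ground-truth mask rescaled to [0, 1]

intersection = np.sum(m1 * m2)                            # 1.7
union = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection  # 1.46 + 2.0 - 1.7 = 1.76
print(intersection / union)                               # ~0.966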

View File

@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import interactive_segmenter
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
_BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image
_ImageFormat = image_frame.ImageFormat
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint
_Rect = rect.Rect
_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
_RegionOfInterest = interactive_segmenter.RegionOfInterest
@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK
base_options=base_options,
output_category_mask=True,
output_confidence_masks=False,
)
segmenter = _InteractiveSegmenter.create_from_options(options)
# Performs image segmentation on the input.
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
category_masks = segmenter.segment(self.test_image, roi)
self.assertLen(category_masks, 1)
category_mask = category_masks[0]
segmentation_result = segmenter.segment(self.test_image, roi)
category_mask = segmentation_result.category_mask
result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct.
@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
self.assertTrue(
_similar_to_uint8_mask(
category_masks[0], test_seg_image, similarity_threshold
category_mask, test_seg_image, similarity_threshold
),
(
'Number of pixels in the candidate mask differing from that of the'
@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
)
with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation
confidence_masks = segmenter.segment(self.test_image, roi)
segmentation_result = segmenter.segment(self.test_image, roi)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
)
with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
confidence_masks = segmenter.segment(
segmentation_result = segmenter.segment(
self.test_image, roi, image_processing_options
)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
)
with self.assertRaisesRegex(

View File

@ -32,6 +32,7 @@ FaceDetectorResult = face_detector.FaceDetectorResult
FaceLandmarker = face_landmarker.FaceLandmarker
FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions
FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult
FaceLandmarksConnections = face_landmarker.FaceLandmarksConnections
FaceStylizer = face_stylizer.FaceStylizer
FaceStylizerOptions = face_stylizer.FaceStylizerOptions
GestureRecognizer = gesture_recognizer.GestureRecognizer

View File

@ -208,6 +208,11 @@ class BaseVisionTaskApi(object):
"""
self._runner.close()
def get_graph_config(self) -> calculator_pb2.CalculatorGraphConfig:
"""Returns the canonicalized CalculatorGraphConfig of the underlying graph.
"""
return self._runner.get_graph_config()
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
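
A minimal sketch of how the new introspection hook can be used from any vision task (here an ImageSegmenter; the model path is illustrative):

from mediapipe.tasks import python as mp_python
from mediapipe.tasks.python import vision

options = vision.ImageSegmenterOptions(
    base_options=mp_python.BaseOptions(model_asset_path='deeplabv3.tflite'))

with vision.ImageSegmenter.create_from_options(options) as segmenter:
    # Canonicalized CalculatorGraphConfig proto of the running graph.
    graph_config = segmenter.get_graph_config()
    print(len(graph_config.node), 'calculator nodes')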

File diff suppressed because it is too large

View File

@ -176,16 +176,13 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
Only use this method when the FaceStylizer is created with the image
running mode.
To ensure that the output image has reasonable quality, the stylized output
image size is the smaller of the model output size and the size of the
`region_of_interest` specified in `image_processing_options`.
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
The stylized image of the most visible face. None if no face is detected
The stylized image of the most visible face. The stylized output image
size is the same as the model output size. None if no face is detected
on the input image.
Raises:
@ -217,17 +214,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
To ensure that the output image has reasonable quality, the stylized output
image size is the smaller of the model output size and the size of the
`region_of_interest` specified in `image_processing_options`.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns:
The stylized image of the most visible face. None if no face is detected
The stylized image of the most visible face. The stylized output image
size is the same as the model output size. None if no face is detected
on the input image.
Raises:
@ -266,12 +260,9 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
images if needed. In other words, it's not guaranteed to have output per
input image.
To ensure that the stylized image has reasonable quality, the stylized
output image size is the smaller of the model output size and the size of
the `region_of_interest` specified in `image_processing_options`.
The `result_callback` provides:
- The stylized image of the most visible face. None if no face is detected
- The stylized image of the most visible face. The stylized output image
size is the same as the model output size. None if no face is detected
on the input image.
- The input image that the face stylizer runs on.
- The input timestamp in milliseconds.
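
A minimal image-mode usage sketch matching the updated documentation; the model and image paths are placeholders:

from mediapipe.tasks import python as mp_python
from mediapipe.tasks.python import vision
import mediapipe as mp

options = vision.FaceStylizerOptions(
    base_options=mp_python.BaseOptions(model_asset_path='face_stylizer.task'))

with vision.FaceStylizer.create_from_options(options) as stylizer:
    image = mp.Image.create_from_file('portrait.jpg')
    stylized = stylizer.stylize(image)
    if stylized is None:
        print('No face detected.')
    else:
        # The stylized output size now always equals the model output size.
        print(stylized.width, stylized.height)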

View File

@ -14,7 +14,6 @@
"""MediaPipe image segmenter task."""
import dataclasses
import enum
from typing import Callable, List, Mapping, Optional
from mediapipe.python import packet_creator
@ -31,7 +30,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
@ -42,8 +40,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
@ -53,6 +53,21 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class ImageSegmenterResult:
"""Output result of ImageSegmenter.
confidence_masks: multiple masks of float image where, for each mask, each
pixel represents the prediction confidence, usually in the [0, 1] range.
category_mask: a category mask of uint8 image where each pixel represents the
class which the pixel in the original image was predicted to belong to.
"""
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass
class ImageSegmenterOptions:
"""Options for the image segmenter task.
@ -64,28 +79,17 @@ class ImageSegmenterOptions:
objects on single image inputs. 2) The video mode for segmenting objects
on the decoded frames of a video. 3) The live stream mode for segmenting
objects on a live stream of input data, such as from camera.
output_type: The output mask type allows specifying the type of
post-processing to perform on the raw model results.
activation: Activation function to apply to input tensor.
output_confidence_masks: Whether to output confidence masks.
output_category_mask: Whether to output category mask.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
activation: Optional[Activation] = Activation.NONE
output_confidence_masks: bool = True
output_category_mask: bool = False
result_callback: Optional[
Callable[[ImageSegmenterResult, image_module.Image, int], None]
] = None
@ -97,9 +101,7 @@ class ImageSegmenterOptions:
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value
)
segmenter_options_proto = _SegmenterOptionsProto()
return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto,
segmenter_options=segmenter_options_proto,
@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
segmentation_result = ImageSegmenterResult()
if options.output_confidence_masks:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if options.output_category_mask:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(
segmentation_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
output_streams=output_streams,
task_options=options,
)
return cls(
@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()
),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
segmentation_result = ImageSegmenterResult()
if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result
def segment_for_video(
@ -285,9 +317,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
segmentation_result = ImageSegmenterResult()
if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result
def segment_async(
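
Putting the new boolean options and the ImageSegmenterResult dataclass together, a minimal sketch (model and image paths are placeholders):

from mediapipe.tasks import python as mp_python
from mediapipe.tasks.python import vision
import mediapipe as mp

options = vision.ImageSegmenterOptions(
    base_options=mp_python.BaseOptions(model_asset_path='deeplabv3.tflite'),
    output_confidence_masks=True,
    output_category_mask=True)

with vision.ImageSegmenter.create_from_options(options) as segmenter:
    image = mp.Image.create_from_file('segmentation_input_rotation0.jpg')
    result = segmenter.segment(image)
    # Only the requested outputs are populated; anything not requested stays None.
    category = result.category_mask.numpy_view()                     # uint8, HxW
    confidences = [m.numpy_view() for m in result.confidence_masks]  # float32, one per class
    print(category.shape, len(confidences))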

View File

@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_ROI_STREAM_NAME = 'roi_in'
@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = (
)
@dataclasses.dataclass
class InteractiveSegmenterResult:
"""Output result of InteractiveSegmenter.
confidence_masks: multiple masks of float image where, for each mask, each
pixel represents the prediction confidence, usually in the [0, 1] range.
category_mask: a category mask of uint8 image where each pixel represents the
class which the pixel in the original image was predicted to belong to.
"""
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass
class InteractiveSegmenterOptions:
"""Options for the interactive segmenter task.
Attributes:
base_options: Base options for the interactive segmenter task.
output_type: The output mask type allows specifying the type of
post-processing to perform on the raw model results.
output_confidence_masks: Whether to output confidence masks.
output_category_mask: Whether to output category mask.
"""
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
base_options: _BaseOptions
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
output_confidence_masks: bool = True
output_category_mask: bool = False
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an InteractiveSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False
segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value
)
segmenter_options_proto = _SegmenterOptionsProto()
return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto,
segmenter_options=segmenter_options_proto,
@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If other types of error occurred.
"""
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
':'.join([_ROI_TAG, _ROI_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
output_streams=output_streams,
task_options=options,
)
return cls(
@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
image: image_module.Image,
roi: RegionOfInterest,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> List[image_module.Image]:
) -> InteractiveSegmenterResult:
"""Performs the actual segmentation task on the provided MediaPipe Image.
The image can be of any size with format RGB.
@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()
),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
segmentation_result = InteractiveSegmenterResult()
if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result
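
The interactive variant follows the same pattern, with a region of interest selecting the target. A sketch: the model and image names are placeholders, and the `Format.KEYPOINT` enum value is assumed from the task's `RegionOfInterest` definition:

from mediapipe.tasks import python as mp_python
from mediapipe.tasks.python.components.containers import keypoint
from mediapipe.tasks.python.vision import interactive_segmenter
import mediapipe as mp

options = interactive_segmenter.InteractiveSegmenterOptions(
    base_options=mp_python.BaseOptions(
        model_asset_path='interactive_segmenter_model.tflite'),  # placeholder
    output_confidence_masks=False,
    output_category_mask=True)

# Normalized keypoint marking the object to segment.
roi = interactive_segmenter.RegionOfInterest(
    format=interactive_segmenter.RegionOfInterest.Format.KEYPOINT,
    keypoint=keypoint.NormalizedKeypoint(x=0.44, y=0.7))

with interactive_segmenter.InteractiveSegmenter.create_from_options(options) as segmenter:
    result = segmenter.segment(mp.Image.create_from_file('cats_and_dogs.jpg'), roi)
    print(result.category_mask.numpy_view().dtype)  # uint8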

View File

@ -59,13 +59,12 @@ export function drawCategoryMask(
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex];
let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
// When we're given a confidence mask by accident, we just log and return.
// TODO: We should fix this.
if (!color) {
// TODO: We should fix this.
console.warn('No color for ', colorIndex);
return;
color = COLOR_MAP[colorIndex % COLOR_MAP.length];
}
rgbaArray[4 * i] = color[0];

View File

@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
*/
export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;
/**
* A callback that receives the computed masks from the segmentation tasks. The
* callback either receives a single element array with a category mask (as a
* `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`).
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type SegmentationMaskCallback =
(masks: SegmentationMask[], width: number, height: number) => void;
/**
* A callback that receives an `ImageData` object from a Vision task. The
* lifetime of the underlying data is limited to the duration of the callback.

View File

@ -19,7 +19,7 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
// tslint:disable:class-as-namespace Using for easier import by 3P users
/**
* A class containing the Pairs of landmark indices to be rendered with
* A class containing the pairs of landmark indices to be rendered with
* connections.
*/
export class FaceLandmarksConnections {

View File

@ -129,10 +129,6 @@ export class FaceStylizer extends VisionTaskRunner {
* synchronously once the callback returns. Only use this method when the
* FaceStylizer is created with the image running mode.
*
* The input image can be of any size. To ensure that the output image has
* reasonable quality, the stylized output image size is determined by the
* model output size.
*
* @param image An image to process.
* @param callback The callback that is invoked with the stylized image. The
* lifetime of the returned data is only guaranteed for the duration of the
@ -153,11 +149,6 @@ export class FaceStylizer extends VisionTaskRunner {
* If both are specified, the crop around the region-of-interest is extracted
* first, then the specified rotation is applied to the crop.
*
* The input image can be of any size. To ensure that the output image has
* reasonable quality, the stylized output image size is the smaller of the
* model output size and the size of the 'regionOfInterest' specified in
* 'imageProcessingOptions'.
*
* @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
@ -192,9 +183,6 @@ export class FaceStylizer extends VisionTaskRunner {
* frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
*
* To ensure that the output image has reasonable quality, the stylized
* output image size is determined by the model output size.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the stylized image. The
@ -221,10 +209,6 @@ export class FaceStylizer extends VisionTaskRunner {
* frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
*
* To ensure that the output image has reasonable quality, the stylized
* output image size is the smaller of the model output size and the size of
* the 'regionOfInterest' specified in 'imageProcessingOptions'.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
@ -278,8 +262,12 @@ export class FaceStylizer extends VisionTaskRunner {
this.graphRunner.attachImageListener(
STYLIZED_IMAGE_STREAM, (image, timestamp) => {
const imageData = this.convertToImageData(image);
this.userCallback(imageData, image.width, image.height);
if (image.data instanceof WebGLTexture) {
this.userCallback(image.data, image.width, image.height);
} else {
const imageData = this.convertToImageData(image);
this.userCallback(imageData, image.width, image.height);
}
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(

View File

@ -34,6 +34,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)

View File

@ -31,6 +31,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -72,6 +73,12 @@ export class GestureRecognizer extends VisionTaskRunner {
private readonly handGestureRecognizerGraphOptions:
HandGestureRecognizerGraphOptions;
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
*/
static HAND_CONNECTIONS = HAND_CONNECTIONS;
/**
* Initializes the Wasm runtime and creates a new gesture recognizer from the
* provided options.

View File

@ -16,6 +16,7 @@ mediapipe_ts_library(
visibility = ["//visibility:public"],
deps = [
":hand_landmarker_types",
":hand_landmarks_connections",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
@ -72,3 +73,9 @@ jasmine_node_test(
tags = ["nomsan"],
deps = [":hand_landmarker_test_lib"],
)
mediapipe_ts_library(
name = "hand_landmarks_connections",
srcs = ["hand_landmarks_connections.ts"],
deps = ["//mediapipe/tasks/web/vision/core:types"],
)

View File

@ -27,6 +27,7 @@ import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/con
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -63,6 +64,12 @@ export class HandLandmarker extends VisionTaskRunner {
HandLandmarksDetectorGraphOptions;
private readonly handDetectorGraphOptions: HandDetectorGraphOptions;
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
*/
static HAND_CONNECTIONS = HAND_CONNECTIONS;
/**
* Initializes the Wasm runtime and creates a new `HandLandmarker` from the
* provided options.

View File

@ -0,0 +1,31 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Connection} from '../../../../tasks/web/vision/core/types';
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
*/
export const HAND_CONNECTIONS: Connection[] = [
{start: 0, end: 1}, {start: 1, end: 2}, {start: 2, end: 3},
{start: 3, end: 4}, {start: 0, end: 5}, {start: 5, end: 6},
{start: 6, end: 7}, {start: 7, end: 8}, {start: 5, end: 9},
{start: 9, end: 10}, {start: 10, end: 11}, {start: 11, end: 12},
{start: 9, end: 13}, {start: 13, end: 14}, {start: 14, end: 15},
{start: 15, end: 16}, {start: 13, end: 17}, {start: 0, end: 17},
{start: 17, end: 18}, {start: 18, end: 19}, {start: 19, end: 20}
];
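
These (start, end) index pairs are meant to be consumed by drawing code. A small Python sketch of how such pairs are typically used; the landmark coordinates are made up:

# Mirror of the HAND_CONNECTIONS pairs above, as plain (start, end) tuples.
HAND_CONNECTIONS = [
    (0, 1), (1, 2), (2, 3), (3, 4), (0, 5), (5, 6), (6, 7), (7, 8),
    (5, 9), (9, 10), (10, 11), (11, 12), (9, 13), (13, 14), (14, 15),
    (15, 16), (13, 17), (0, 17), (17, 18), (18, 19), (19, 20),
]

# Made-up normalized landmarks; a real pipeline would take these from a
# HandLandmarker / GestureRecognizer result.
landmarks = [(0.5 + 0.01 * i, 0.5 - 0.01 * i) for i in range(21)]

def to_pixels(point, width=640, height=480):
    x, y = point
    return int(x * width), int(y * height)

for start, end in HAND_CONNECTIONS:
    p0, p1 = to_pixels(landmarks[start]), to_pixels(landmarks[end])
    print(f'line {p0} -> {p1}')  # e.g. hand off to cv2.line(frame, p0, p1, ...)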

View File

@ -29,7 +29,10 @@ mediapipe_ts_library(
mediapipe_ts_declaration(
name = "image_segmenter_types",
srcs = ["image_segmenter_options.d.ts"],
srcs = [
"image_segmenter_options.d.ts",
"image_segmenter_result.d.ts",
],
deps = [
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",

View File

@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {LabelMapItem} from '../../../../util/label_map_pb';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
import {ImageSegmenterOptions} from './image_segmenter_options';
import {ImageSegmenterResult} from './image_segmenter_result';
export * from './image_segmenter_options';
export {SegmentationMask, SegmentationMaskCallback};
export * from './image_segmenter_result';
export {SegmentationMask};
export {ImageSource}; // Used in the public API
const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
'mediapipe.tasks.TensorsToSegmentationCalculator';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* A callback that receives the computed masks from the image segmenter. The
* returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
/** Performs image segmentation on images. */
export class ImageSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {};
private result: ImageSegmenterResult = {width: 0, height: 0};
private labels: string[] = [];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto;
@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
this.options.setBaseOptions(new BaseOptionsProto());
}
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
this.options.clearDisplayNamesLocale();
}
if (options.outputType === 'CONFIDENCE_MASK') {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
} else {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
if ('outputCategoryMask' in options) {
this.outputCategoryMask =
options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
}
if ('outputConfidenceMasks' in options) {
this.outputConfidenceMasks =
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
}
return super.applyOptions(options);
@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segment(image: ImageSource, callback: SegmentationMaskCallback): void;
segment(image: ImageSource, callback: ImageSegmenterCallack): void;
/**
* Performs image segmentation on the provided single image and invokes the
* callback with the response. The method returns synchronously once the
@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
*/
segment(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void;
callback: ImageSegmenterCallack): void;
segment(
image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
ImageSegmenterCallack,
callback?: ImageSegmenterCallack): void {
const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
const userCallback =
typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback!;
this.reset();
this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
userCallback(this.result);
}
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, timestamp: number,
callback: ImageSegmenterCallack): void;
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: ImageSegmenterCallack): void;
segmentForVideo(
videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|ImageSegmenterCallack,
callback?: ImageSegmenterCallack): void {
const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions :
{};
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
const userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.reset();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
userCallback(this.result);
}
/**
@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
return this.labels;
}
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, timestamp: number,
callback: SegmentationMaskCallback): void;
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: SegmentationMaskCallback): void;
segmentForVideo(
videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions :
{};
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
this.userCallback = () => {};
private reset(): void {
this.result = {width: 0, height: 0};
}
/** Updates the MediaPipe graph configuration. */
@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM);
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
segmenterNode.addOutputStream(
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
segmenterNode.setOptions(calculatorOptions);
graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageVectorListener(
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
if (masks.length === 0) {
this.userCallback([], 0, 0);
} else {
this.userCallback(
masks.map(m => m.data), masks[0].width, masks[0].height);
}
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
GROUPED_SEGMENTATIONS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
if (this.outputConfidenceMasks) {
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
segmenterNode.addOutputStream(
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map(m => m.data);
          if (masks.length > 0) {
this.result.width = masks[0].width;
this.result.height = masks[0].height;
}
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
if (this.outputCategoryMask) {
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = mask.data;
this.result.width = mask.width;
this.result.height = mask.height;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
*/
displayNamesLocale?: string|undefined;
/**
* The output type of segmentation results.
*
* The two supported modes are:
* - Category Mask: Gives a single output mask where each pixel represents
* the class which the pixel in the original image was
* predicted to belong to.
* - Confidence Mask: Gives a list of output masks (one for each class). For
* each mask, the pixel represents the prediction
 *                     confidence, usually in the [0.0, 1.0] range.
*
* Defaults to `CATEGORY_MASK`.
*/
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
/** Whether to output confidence masks. Defaults to true. */
outputConfidenceMasks?: boolean|undefined;
/** Whether to output the category masks. Defaults to false. */
outputCategoryMask?: boolean|undefined;
}
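A minimal setup sketch using the new boolean options (a sketch only; the import path, wasm path, and model URL below are placeholders, not values from this change):

  import {FilesetResolver, ImageSegmenter} from '@mediapipe/tasks-vision';

  async function createSegmenter(): Promise<ImageSegmenter> {
    const fileset = await FilesetResolver.forVisionTasks('/wasm');  // placeholder path
    return ImageSegmenter.createFromOptions(fileset, {
      baseOptions: {modelAssetPath: '/models/segmenter.tflite'},  // placeholder model
      outputConfidenceMasks: true,  // default
      outputCategoryMask: true,     // off by default; enables result.categoryMask
    });
  }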

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** The output result of ImageSegmenter. */
export declare interface ImageSegmenterResult {
/**
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
* pixel represents the prediction confidence, usually in the [0, 1] range.
*/
confidenceMasks?: Float32Array[]|WebGLTexture[];
/**
* A category mask as a Uint8ClampedArray or WebGLTexture where each
* pixel represents the class which the pixel in the original image was
* predicted to belong to.
*/
categoryMask?: Uint8ClampedArray|WebGLTexture;
/** The width of the masks. */
width: number;
/** The height of the masks. */
height: number;
}
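A small sketch (not part of this change) of consuming the CPU variant of the result; the WebGLTexture variant would instead be drawn to a canvas:

  function countConfidentPixels(
      result: ImageSegmenterResult, threshold = 0.5): number {
    const mask = result.confidenceMasks?.[0];
    // WebGLTexture output is not handled in this sketch.
    if (!(mask instanceof Float32Array)) return 0;
    let count = 0;
    for (const confidence of mask) {
      if (confidence > threshold) count++;
    }
    return count;
  }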

View File

@ -18,7 +18,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {ImageSegmenter} from './image_segmenter';
@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule;
imageVectorListener:
categoryMaskListener:
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
constructor() {
@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] =
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('category_mask');
this.categoryMaskListener = listener;
});
this.attachListenerSpies[1] =
spyOn(this.graphRunner, 'attachImageVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('segmented_masks');
this.imageVectorListener = listener;
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
it('initializes graph', async () => {
verifyGraph(imageSegmenter);
verifyListenersRegistered(imageSegmenter);
// Verify default options
expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
});
it('reloads graph when settings are changed', async () => {
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
verifyListenersRegistered(imageSegmenter);
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
verifyListenersRegistered(imageSegmenter);
});
it('can use custom models', async () => {
@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
});
it('merges options', async () => {
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
await imageSegmenter.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
verifyGraph(
imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
});
@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
defaultValue: unknown;
}
const testCases: TestCase[] = [
{
optionName: 'displayNamesLocale',
fieldPath: ['displayNamesLocale'],
userValue: 'en',
graphValue: 'en',
defaultValue: 'en'
},
{
optionName: 'outputType',
fieldPath: ['segmenterOptions', 'outputType'],
userValue: 'CONFIDENCE_MASK',
graphValue: 2,
defaultValue: 1
},
];
const testCases: TestCase[] = [{
optionName: 'displayNamesLocale',
fieldPath: ['displayNamesLocale'],
userValue: 'en',
graphValue: 'en',
defaultValue: 'en'
}];
for (const testCase of testCases) {
it(`can set ${testCase.optionName}`, async () => {
@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('supports category masks', (done) => {
it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
await imageSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false});
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageSegmenter);
imageSegmenter.imageVectorListener!(
[
{data: mask, width: 2, height: 2},
],
/* timestamp= */ 1337);
expect(imageSegmenter.categoryMaskListener).toBeDefined();
imageSegmenter.categoryMaskListener!
({data: mask, width: 2, height: 2},
/* timestamp= */ 1337);
});
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(masks).toHaveSize(1);
expect(masks[0]).toEqual(mask);
expect(width).toEqual(2);
expect(height).toEqual(2);
done();
return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toEqual(mask);
expect(result.confidenceMasks).not.toBeDefined();
expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve();
});
});
});
@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
await imageSegmenter.setOptions(
{outputCategoryMask: false, outputConfidenceMasks: true});
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageSegmenter);
imageSegmenter.imageVectorListener!(
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
imageSegmenter.confidenceMasksListener!(
[
{data: mask1, width: 2, height: 2},
{data: mask2, width: 2, height: 2},
@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {
return new Promise<void>(resolve => {
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(masks).toHaveSize(2);
expect(masks[0]).toEqual(mask1);
expect(masks[1]).toEqual(mask2);
expect(width).toEqual(2);
expect(height).toEqual(2);
expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks).toEqual([mask1, mask2]);
expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve();
});
});
});
it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1, 0]);
const confidenceMask1 = new Float32Array([0.0, 1.0]);
const confidenceMask2 = new Float32Array([1.0, 0.0]);
await imageSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: true});
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(imageSegmenter.categoryMaskListener).toBeDefined();
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
imageSegmenter.categoryMaskListener!
({data: categoryMask, width: 1, height: 1}, 1337);
imageSegmenter.confidenceMasksListener!(
[
{data: confidenceMask1, width: 1, height: 1},
{data: confidenceMask2, width: 1, height: 1},
],
1337);
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toEqual(categoryMask);
expect(result.confidenceMasks).toEqual([
confidenceMask1, confidenceMask2
]);
expect(result.width).toEqual(1);
expect(result.height).toEqual(1);
resolve();
});
});

View File

@ -30,7 +30,10 @@ mediapipe_ts_library(
mediapipe_ts_declaration(
name = "interactive_segmenter_types",
srcs = ["interactive_segmenter_options.d.ts"],
srcs = [
"interactive_segmenter_options.d.ts",
"interactive_segmenter_result.d.ts",
],
deps = [
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",

View File

@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {Color as ColorProto} from '../../../../util/color_pb';
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
// Placeholder for internal dependency on trusted resource url
import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
import {InteractiveSegmenterResult} from './interactive_segmenter_result';
export * from './interactive_segmenter_options';
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
export * from './interactive_segmenter_result';
export {SegmentationMask, RegionOfInterest};
export {ImageSource};
const IMAGE_IN_STREAM = 'image_in';
const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in';
const IMAGE_OUT_STREAM = 'image_out';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* A callback that receives the computed masks from the interactive segmenter.
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type InteractiveSegmenterCallack =
(result: InteractiveSegmenterResult) => void;
/**
* Performs interactive segmentation on images.
*
@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH =
* - batch is always 1
*/
export class InteractiveSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {};
private result: InteractiveSegmenterResult = {width: 0, height: 0};
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto;
@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
* @return A Promise that resolves when the settings have been applied.
*/
override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
if (options.outputType === 'CONFIDENCE_MASK') {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
} else {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
if ('outputCategoryMask' in options) {
this.outputCategoryMask =
options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
}
if ('outputConfidenceMasks' in options) {
this.outputConfidenceMasks =
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
}
return super.applyOptions(options);
@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
*/
segment(
image: ImageSource, roi: RegionOfInterest,
callback: SegmentationMaskCallback): void;
callback: InteractiveSegmenterCallack): void;
/**
* Performs interactive segmentation on the provided single image and invokes
* the callback with the response. The `roi` parameter is used to represent a
@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner {
segment(
image: ImageSource, roi: RegionOfInterest,
imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void;
callback: InteractiveSegmenterCallack): void;
segment(
image: ImageSource, roi: RegionOfInterest,
imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
InteractiveSegmenterCallack,
callback?: InteractiveSegmenterCallack): void {
const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
const userCallback =
typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback!;
this.reset();
this.processRenderData(roi, this.getSynctheticTimestamp());
this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
userCallback(this.result);
}
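A hedged usage sketch (not part of this change), assuming `RegionOfInterest` carries a normalized keypoint and `interactiveSegmenter` is an instance of this class:

  const roi = {keypoint: {x: 0.42, y: 0.63}};  // assumed RegionOfInterest shape
  interactiveSegmenter.segment(imageElement, roi, result => {
    // Populated according to outputConfidenceMasks / outputCategoryMask.
    console.log(result.confidenceMasks?.length, result.width, result.height);
  });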
private reset(): void {
this.result = {width: 0, height: 0};
}
/** Updates the MediaPipe graph configuration. */
@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner {
graphConfig.addInputStream(IMAGE_IN_STREAM);
graphConfig.addInputStream(ROI_IN_STREAM);
graphConfig.addInputStream(NORM_RECT_IN_STREAM);
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner {
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
segmenterNode.setOptions(calculatorOptions);
graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageVectorListener(
IMAGE_OUT_STREAM, (masks, timestamp) => {
if (masks.length === 0) {
this.userCallback([], 0, 0);
} else {
this.userCallback(
masks.map(m => m.data), masks[0].width, masks[0].height);
}
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
if (this.outputConfidenceMasks) {
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
segmenterNode.addOutputStream(
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map(m => m.data);
            if (masks.length > 0) {
this.result.width = masks[0].width;
this.result.height = masks[0].height;
}
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
if (this.outputCategoryMask) {
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = mask.data;
this.result.width = mask.width;
this.result.height = mask.height;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'
/** Options to configure the MediaPipe Interactive Segmenter Task */
export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
/**
* The output type of segmentation results.
*
* The two supported modes are:
* - Category Mask: Gives a single output mask where each pixel represents
* the class which the pixel in the original image was
* predicted to belong to.
* - Confidence Mask: Gives a list of output masks (one for each class). For
* each mask, the pixel represents the prediction
 *                     confidence, usually in the [0.0, 1.0] range.
*
* Defaults to `CATEGORY_MASK`.
*/
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
/** Whether to output confidence masks. Defaults to true. */
outputConfidenceMasks?: boolean|undefined;
/** Whether to output the category masks. Defaults to false. */
outputCategoryMask?: boolean|undefined;
}

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** The output result of InteractiveSegmenter. */
export declare interface InteractiveSegmenterResult {
/**
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
* pixel represents the prediction confidence, usually in the [0, 1] range.
*/
confidenceMasks?: Float32Array[]|WebGLTexture[];
/**
* A category mask as a Uint8ClampedArray or WebGLTexture where each
* pixel represents the class which the pixel in the original image was
* predicted to belong to.
*/
categoryMask?: Uint8ClampedArray|WebGLTexture;
/** The width of the masks. */
width: number;
/** The height of the masks. */
height: number;
}

View File

@ -18,7 +18,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule;
imageVectorListener:
categoryMaskListener:
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto;
@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] =
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('category_mask');
this.categoryMaskListener = listener;
});
this.attachListenerSpies[1] =
spyOn(this.graphRunner, 'attachImageVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('image_out');
this.imageVectorListener = listener;
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => {
it('initializes graph', async () => {
verifyGraph(interactiveSegmenter);
verifyListenersRegistered(interactiveSegmenter);
// Verify default options
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
});
it('reloads graph when settings are changed', async () => {
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
verifyListenersRegistered(interactiveSegmenter);
await interactiveSegmenter.setOptions(
{outputConfidenceMasks: true, outputCategoryMask: false});
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
verifyListenersRegistered(interactiveSegmenter);
await interactiveSegmenter.setOptions(
{outputConfidenceMasks: false, outputCategoryMask: true});
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
});
it('can use custom models', async () => {
@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => {
]);
});
describe('setOptions()', () => {
const fieldPath = ['segmenterOptions', 'outputType'];
it(`can set outputType`, async () => {
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
});
it(`can clear outputType`, async () => {
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
await interactiveSegmenter.setOptions({outputType: undefined});
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
});
});
it('doesn\'t support region of interest', () => {
expect(() => {
interactiveSegmenter.segment(
@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
});
it('supports category masks', (done) => {
it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false});
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(interactiveSegmenter);
interactiveSegmenter.imageVectorListener!(
[
{data: mask, width: 2, height: 2},
],
/* timestamp= */ 1337);
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
interactiveSegmenter.categoryMaskListener!
({data: mask, width: 2, height: 2},
/* timestamp= */ 1337);
});
// Invoke the image segmenter
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, (masks, width, height) => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(masks).toHaveSize(1);
expect(masks[0]).toEqual(mask);
expect(width).toEqual(2);
expect(height).toEqual(2);
done();
});
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).toEqual(mask);
expect(result.confidenceMasks).not.toBeDefined();
expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve();
});
});
});
it('supports confidence masks', async () => {
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
await interactiveSegmenter.setOptions(
{outputCategoryMask: false, outputConfidenceMasks: true});
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(interactiveSegmenter);
interactiveSegmenter.imageVectorListener!(
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
interactiveSegmenter.confidenceMasksListener!(
[
{data: mask1, width: 2, height: 2},
{data: mask2, width: 2, height: 2},
],
1337);
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks).toEqual([mask1, mask2]);
expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve();
});
});
});
it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1, 0]);
const confidenceMask1 = new Float32Array([0.0, 1.0]);
const confidenceMask2 = new Float32Array([1.0, 0.0]);
await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: true});
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
interactiveSegmenter.categoryMaskListener!
({data: categoryMask, width: 1, height: 1}, 1337);
interactiveSegmenter.confidenceMasksListener!(
[
{data: confidenceMask1, width: 1, height: 1},
{data: confidenceMask2, width: 1, height: 1},
],
1337);
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, (masks, width, height) => {
{} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(masks).toHaveSize(2);
expect(masks[0]).toEqual(mask1);
expect(masks[1]).toEqual(mask2);
expect(width).toEqual(2);
expect(height).toEqual(2);
expect(result.categoryMask).toEqual(categoryMask);
expect(result.confidenceMasks).toEqual([
confidenceMask1, confidenceMask2
]);
expect(result.width).toEqual(1);
expect(result.height).toEqual(1);
resolve();
});
});

View File

@ -56,8 +56,8 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y,
VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0";
}
*x_px = static_cast<int32>(round(normalized_x * image_width));
*y_px = static_cast<int32>(round(normalized_y * image_height));
*x_px = static_cast<int32_t>(round(normalized_x * image_width));
*y_px = static_cast<int32_t>(round(normalized_y * image_height));
return true;
}

View File

@ -43,7 +43,7 @@ ABSL_FLAG(std::string, system_cpu_max_freq_file,
namespace mediapipe {
namespace {
constexpr uint32 kBufferLength = 64;
constexpr uint32_t kBufferLength = 64;
absl::StatusOr<std::string> GetFilePath(int cpu) {
if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
@ -54,7 +54,7 @@ absl::StatusOr<std::string> GetFilePath(int cpu) {
return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu);
}
absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
auto path_or_status = GetFilePath(cpu);
if (!path_or_status.ok()) {
return path_or_status.status();
@ -65,7 +65,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
char buffer[kBufferLength];
file.getline(buffer, kBufferLength);
file.close();
uint64 frequency;
uint64_t frequency;
if (absl::SimpleAtoi(buffer, &frequency)) {
return frequency;
} else {
@ -79,7 +79,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
}
std::set<int> InferLowerOrHigherCoreIds(bool lower) {
std::vector<std::pair<int, uint64>> cpu_freq_pairs;
std::vector<std::pair<int, uint64_t>> cpu_freq_pairs;
for (int cpu = 0; cpu < NumCPUCores(); ++cpu) {
auto freq_or_status = GetCpuMaxFrequency(cpu);
if (freq_or_status.ok()) {
@ -90,12 +90,12 @@ std::set<int> InferLowerOrHigherCoreIds(bool lower) {
return {};
}
absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64>& left,
const std::pair<int, uint64>& right) {
absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64_t>& left,
const std::pair<int, uint64_t>& right) {
return (lower && left.second < right.second) ||
(!lower && left.second > right.second);
});
uint64 edge_freq = cpu_freq_pairs[0].second;
uint64_t edge_freq = cpu_freq_pairs[0].second;
std::set<int> inferred_cores;
for (const auto& cpu_freq_pair : cpu_freq_pairs) {

View File

@ -89,12 +89,12 @@ void ImageFrameToYUVImage(const ImageFrame& image_frame, YUVImage* yuv_image) {
const int uv_stride = (uv_width + 15) & ~15;
const int y_size = y_stride * height;
const int uv_size = uv_stride * uv_height;
uint8* data =
reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size * 2, 16));
uint8_t* data =
reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size * 2, 16));
std::function<void()> deallocate = [data]() { aligned_free(data); };
uint8* y = data;
uint8* u = y + y_size;
uint8* v = u + uv_size;
uint8_t* y = data;
uint8_t* u = y + y_size;
uint8_t* v = u + uv_size;
yuv_image->Initialize(libyuv::FOURCC_I420, deallocate, //
y, y_stride, //
u, uv_stride, //
@ -123,10 +123,11 @@ void ImageFrameToYUVNV12Image(const ImageFrame& image_frame,
const int uv_stride = y_stride;
const int uv_height = (height + 1) / 2;
const int uv_size = uv_stride * uv_height;
uint8* data = reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size, 16));
uint8_t* data =
reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size, 16));
std::function<void()> deallocate = [data] { aligned_free(data); };
uint8* y = data;
uint8* uv = y + y_size;
uint8_t* y = data;
uint8_t* uv = y + y_size;
yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv,
uv_stride, nullptr, 0, width, height);
const int rv = libyuv::I420ToNV12(
@ -210,44 +211,44 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image,
}
}
void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b, //
uint8* y, uint8* cb, uint8* cr) {
void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b, //
uint8_t* y, uint8_t* cb, uint8_t* cr) {
// ITU-R BT.601 conversion from sRGB to YCbCr.
// FastIntRound is used rather than SafeRound since the possible
// range of values is [16,235] for Y and [16,240] for Cb and Cr and we
// don't care about the rounding direction for values exactly between
// two integers.
*y = static_cast<uint8>(
*y = static_cast<uint8_t>(
mediapipe::MathUtil::FastIntRound(16.0 + //
65.481 * r / 255.0 + //
128.553 * g / 255.0 + //
24.966 * b / 255.0));
*cb = static_cast<uint8>(
*cb = static_cast<uint8_t>(
mediapipe::MathUtil::FastIntRound(128.0 + //
-37.797 * r / 255.0 + //
-74.203 * g / 255.0 + //
112.0 * b / 255.0));
*cr = static_cast<uint8>(
*cr = static_cast<uint8_t>(
mediapipe::MathUtil::FastIntRound(128.0 + //
112.0 * r / 255.0 + //
-93.786 * g / 255.0 + //
-18.214 * b / 255.0));
}
void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, //
uint8* r, uint8* g, uint8* b) {
void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr, //
uint8_t* r, uint8_t* g, uint8_t* b) {
// ITU-R BT.601 conversion from YCbCr to sRGB
// Use SafeRound since many MPEG YCbCr values do not correspond directly
// to an sRGB value.
*r = mediapipe::MathUtil::SafeRound<uint8, double>( //
255.0 / 219.0 * (y - 16.0) + //
*r = mediapipe::MathUtil::SafeRound<uint8_t, double>( //
255.0 / 219.0 * (y - 16.0) + //
255.0 / 112.0 * 0.701 * (cr - 128.0));
*g = mediapipe::MathUtil::SafeRound<uint8, double>(
*g = mediapipe::MathUtil::SafeRound<uint8_t, double>(
255.0 / 219.0 * (y - 16.0) - //
255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) - //
255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0));
*b = mediapipe::MathUtil::SafeRound<uint8, double>( //
255.0 / 219.0 * (y - 16.0) + //
*b = mediapipe::MathUtil::SafeRound<uint8_t, double>( //
255.0 / 219.0 * (y - 16.0) + //
255.0 / 112.0 * 0.886 * (cb - 128.0));
}
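For reference, the forward ITU-R BT.601 studio-range mapping that SrgbToMpegYCbCr implements can be written as (a standard formulation, not text from this change; R, G, B are in [0, 255]):

\[
\begin{aligned}
Y   &= 16  + \frac{65.481\,R + 128.553\,G + 24.966\,B}{255},\\
C_b &= 128 + \frac{-37.797\,R - 74.203\,G + 112.0\,B}{255},\\
C_r &= 128 + \frac{112.0\,R - 93.786\,G - 18.214\,B}{255},
\end{aligned}
\]

which keeps Y in [16, 235] and Cb, Cr in [16, 240]; MpegYCbCrToSrgb applies the inverse, using SafeRound since not every YCbCr triple maps exactly to an sRGB value.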
@ -260,15 +261,15 @@ void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, //
cv::Mat GetSrgbToLinearRgb16Lut() {
cv::Mat lut(1, 256, CV_16UC1);
uint16* ptr = lut.ptr<uint16>();
uint16_t* ptr = lut.ptr<uint16_t>();
constexpr double kUint8Max = 255.0;
constexpr double kUint16Max = 65535.0;
for (int i = 0; i < 256; ++i) {
if (i < 0.04045 * kUint8Max) {
ptr[i] = static_cast<uint16>(
ptr[i] = static_cast<uint16_t>(
(static_cast<double>(i) / kUint8Max / 12.92) * kUint16Max + .5);
} else {
ptr[i] = static_cast<uint16>(
ptr[i] = static_cast<uint16_t>(
pow((static_cast<double>(i) / kUint8Max + 0.055) / 1.055, 2.4) *
kUint16Max +
.5);
@ -279,15 +280,15 @@ cv::Mat GetSrgbToLinearRgb16Lut() {
cv::Mat GetLinearRgb16ToSrgbLut() {
cv::Mat lut(1, 65536, CV_8UC1);
uint8* ptr = lut.ptr<uint8>();
uint8_t* ptr = lut.ptr<uint8_t>();
constexpr double kUint8Max = 255.0;
constexpr double kUint16Max = 65535.0;
for (int i = 0; i < 65536; ++i) {
if (i < 0.0031308 * kUint16Max) {
ptr[i] = static_cast<uint8>(
ptr[i] = static_cast<uint8_t>(
(static_cast<double>(i) / kUint16Max * 12.92) * kUint8Max + .5);
} else {
ptr[i] = static_cast<uint8>(
ptr[i] = static_cast<uint8_t>(
(1.055 * pow(static_cast<double>(i) / kUint16Max, 1.0 / 2.4) - .055) *
kUint8Max +
.5);
@ -306,13 +307,13 @@ void LinearRgb16ToSrgb(const cv::Mat& source, cv::Mat* destination) {
destination->create(source.size(), CV_8UC(source.channels()));
static const cv::Mat kLut = GetLinearRgb16ToSrgbLut();
const uint8* lookup_table_ptr = kLut.ptr<uint8>();
const uint8_t* lookup_table_ptr = kLut.ptr<uint8_t>();
const int num_channels = source.channels();
for (int row = 0; row < source.rows; ++row) {
for (int col = 0; col < source.cols; ++col) {
for (int channel = 0; channel < num_channels; ++channel) {
uint8* ptr = destination->ptr<uint8>(row);
const uint16* ptr16 = source.ptr<uint16>(row);
uint8_t* ptr = destination->ptr<uint8_t>(row);
const uint16_t* ptr16 = source.ptr<uint16_t>(row);
ptr[col * num_channels + channel] =
lookup_table_ptr[ptr16[col * num_channels + channel]];
}

View File

@ -43,14 +43,14 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
input.rows, input.step, input.data, [](uint8*) {});
input.rows, input.step, input.data, [](uint8_t*) {});
return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0));
}
Packet MakeImagePacket(cv::Mat input, int timestamp) {
mediapipe::Image input_image(std::make_shared<mediapipe::ImageFrame>(
GetImageFormat(input.channels()), input.cols, input.rows, input.step,
input.data, [](uint8*) {}));
input.data, [](uint8_t*) {}));
return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(0));
}

View File

@ -25,7 +25,7 @@
namespace mediapipe {
absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
absl::StatusOr<proto_ns::Map<int64_t, LabelMapItem>> BuildLabelMapFromFiles(
absl::string_view labels_file_contents,
absl::string_view display_names_file) {
if (labels_file_contents.empty()) {
@ -68,7 +68,7 @@ absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
label_map_items[i].set_display_name(display_names[i]);
}
}
proto_ns::Map<int64, LabelMapItem> label_map;
proto_ns::Map<int64_t, LabelMapItem> label_map;
for (int i = 0; i < label_map_items.size(); ++i) {
label_map[i] = label_map_items[i];
}

View File

@ -45,12 +45,16 @@ filegroup(
"include/flatbuffers/bfbs_generator.h",
"include/flatbuffers/buffer.h",
"include/flatbuffers/buffer_ref.h",
"include/flatbuffers/code_generator.h",
"include/flatbuffers/code_generators.h",
"include/flatbuffers/default_allocator.h",
"include/flatbuffers/detached_buffer.h",
"include/flatbuffers/flatbuffer_builder.h",
"include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flatc.h",
"include/flatbuffers/flex_flat_util.h",
"include/flatbuffers/flexbuffers.h",
"include/flatbuffers/grpc.h",
"include/flatbuffers/hash.h",
"include/flatbuffers/idl.h",
"include/flatbuffers/minireflect.h",

View File

@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "flatbuffers",
strip_prefix = "flatbuffers-2.0.6",
sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9",
strip_prefix = "flatbuffers-23.1.21",
sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v2.0.6.tar.gz",
"https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
"https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
],
build_file = "//third_party/flatbuffers:BUILD.bazel",
delete = ["build_defs.bzl", "BUILD.bazel"],

View File

@ -12,72 +12,72 @@ def wasm_files():
http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"],
sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"],
)
http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"],
sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"],
)
http_file(
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"],
sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"],
)
http_file(
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"],
sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"],
)
http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_js",
sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"],
sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"],
)
http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"],
sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"],
)
http_file(
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"],
sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"],
)
http_file(
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"],
sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"],
)
http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"],
sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"],
)
http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
sha256 = "03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"],
sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"],
)
http_file(
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"],
sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"],
)
http_file(
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
sha256 = "103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"],
sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"],
)