diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.cc b/mediapipe/calculators/core/constant_side_packet_calculator.cc
index 45ff07110..509f7e9dd 100644
--- a/mediapipe/calculators/core/constant_side_packet_calculator.cc
+++ b/mediapipe/calculators/core/constant_side_packet_calculator.cc
@@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
     } else if (packet_options.has_string_value()) {
       packet.Set<std::string>();
     } else if (packet_options.has_uint64_value()) {
-      packet.Set<uint64>();
+      packet.Set<uint64_t>();
     } else if (packet_options.has_classification_list_value()) {
       packet.Set<ClassificationList>();
     } else if (packet_options.has_landmark_list_value()) {
@@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
     } else if (packet_options.has_string_value()) {
       packet.Set(MakePacket<std::string>(packet_options.string_value()));
     } else if (packet_options.has_uint64_value()) {
-      packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
+      packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
     } else if (packet_options.has_classification_list_value()) {
       packet.Set(MakePacket<ClassificationList>(
          packet_options.classification_list_value()));
diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc
index c734ddb5f..192019820 100644
--- a/mediapipe/calculators/core/gate_calculator_test.cc
+++ b/mediapipe/calculators/core/gate_calculator_test.cc
@@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
   }
 
   // Use this when ALLOW/DISALLOW input is provided as a side packet.
-  void RunTimeStep(int64 timestamp, bool stream_payload) {
+  void RunTimeStep(int64_t timestamp, bool stream_payload) {
     runner_->MutableInputs()->Get("", 0).packets.push_back(
         MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
     MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
   }
 
   // Use this when ALLOW/DISALLOW input is provided as an input stream.
- void RunTimeStep(int64 timestamp, const std::string& control_tag, + void RunTimeStep(int64_t timestamp, const std::string& control_tag, bool control) { runner_->MutableInputs()->Get("", 0).packets.push_back( MakePacket(true).At(Timestamp(timestamp))); @@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) { } )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) { } )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) { output_stream: "test_output" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { )"); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) { )"); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) { )"); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) { )"); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) { output_stream: "test_output" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; 
RunTimeStep(kTimestampValue0, "ALLOW", true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "ALLOW", false); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "ALLOW", true); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "ALLOW", false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) { output_stream: "test_output" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "DISALLOW", true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "DISALLOW", false); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "DISALLOW", true); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "DISALLOW", false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "ALLOW", false); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "ALLOW", true); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "ALLOW", true); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "ALLOW", false); const std::vector& output = @@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "DISALLOW", true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "DISALLOW", false); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "DISALLOW", false); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "DISALLOW", true); const std::vector& output = @@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "DISALLOW", false); const std::vector& output = @@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "ALLOW", true); const std::vector& output = diff --git a/mediapipe/calculators/core/string_to_int_calculator.cc b/mediapipe/calculators/core/string_to_int_calculator.cc index ecd55afb6..fa67aa8e5 100644 --- a/mediapipe/calculators/core/string_to_int_calculator.cc +++ b/mediapipe/calculators/core/string_to_int_calculator.cc @@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator); using StringToUintCalculator = StringToIntCalculatorTemplate; 
 REGISTER_CALCULATOR(StringToUintCalculator);
 
-using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
+using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
 REGISTER_CALCULATOR(StringToInt32Calculator);
 
-using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
+using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
 REGISTER_CALCULATOR(StringToUint32Calculator);
 
-using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
+using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
 REGISTER_CALCULATOR(StringToInt64Calculator);
 
-using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
+using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
 REGISTER_CALCULATOR(StringToUint64Calculator);
 
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc
index 388701773..dcc371036 100644
--- a/mediapipe/calculators/image/warp_affine_calculator.cc
+++ b/mediapipe/calculators/image/warp_affine_calculator.cc
@@ -166,7 +166,7 @@ class WarpAffineRunnerHolder {
       const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
                                    frame_ptr->Height(), frame_ptr->WidthStep(),
                                    const_cast<uint8_t*>(frame_ptr->PixelData()),
-                                   [](uint8* data){});
+                                   [](uint8_t* data){});
      ASSIGN_OR_RETURN(auto result,
                       runner->Run(image_frame, matrix, size, border_mode));
      return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc
index d53acedc9..fd51a7383 100644
--- a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc
+++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc
@@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
       ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
     // Record the most recent first kept timestamp on any stream.
     for (const auto& stream : input_stream_managers_) {
-      int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
-                             ? target_queue_size_
-                             : trigger_queue_size_ - 1;
+      int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
+                               ? target_queue_size_
+                               : trigger_queue_size_ - 1;
       if (stream->QueueSize() > queue_size) {
         kept_timestamp_ = std::max(
             kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
@@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
   }
 
  private:
-  int32 trigger_queue_size_;
-  int32 target_queue_size_;
+  int32_t trigger_queue_size_;
+  int32_t target_queue_size_;
   bool fixed_min_size_;
   // Indicates that GetNodeReadiness has returned kReadyForProcess once, and
   // the corresponding call to FillInputSet has not yet completed.
diff --git a/mediapipe/gpu/gl_context_webgl.cc b/mediapipe/gpu/gl_context_webgl.cc
index b1f5295c9..25cbed83d 100644
--- a/mediapipe/gpu/gl_context_webgl.cc
+++ b/mediapipe/gpu/gl_context_webgl.cc
@@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
   // TODO: Investigate this option in more detail, esp. on Safari.
   attrs.preserveDrawingBuffer = 0;
 
-  // Since the Emscripten canvas target finding function is visible from here,
-  // we hijack findCanvasEventTarget directly for enforcing old Module.canvas
-  // behavior if the user desires, falling back to the new DOM element CSS
-  // selector behavior next if that is specified, and finally just allowing the
-  // lookup to proceed on a null target.
- // TODO: Ensure this works with all options (in particular, - // multithreading options, like the special-case combination of USE_PTHREADS - // and OFFSCREEN_FRAMEBUFFER) - // clang-format off - EM_ASM( - let init_once = true; - if (init_once) { - const cachedFindCanvasEventTarget = findCanvasEventTarget; - - if (typeof cachedFindCanvasEventTarget !== 'function') { - if (typeof console !== 'undefined') { - console.error('Expected Emscripten global function ' - + '"findCanvasEventTarget" not found. WebGL context creation ' - + 'may fail.'); - } - return; - } - - findCanvasEventTarget = function(target) { - if (target == 0) { - if (Module && Module.canvas) { - return Module.canvas; - } else if (Module && Module.canvasCssSelector) { - return cachedFindCanvasEventTarget(Module.canvasCssSelector); - } - if (typeof console !== 'undefined') { - console.warn('Module properties canvas and canvasCssSelector not ' + - 'found during WebGL context creation.'); - } - } - // We still go through with the find attempt, although for most use - // cases it will not succeed, just in case the user does want to fall- - // back. - return cachedFindCanvasEventTarget(target); - }; // NOLINT: Necessary semicolon. - init_once = false; - } - ); - // clang-format on - + // Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also + // looks for our #canvas target in Module.canvas, where we expect it to be. + // -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new + // event target behavior, but it was never supposed to be tapping into our + // canvas anyways. See b/278155946 for more background. + EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; }); EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle = - emscripten_webgl_create_context(nullptr, &attrs); + emscripten_webgl_create_context("#canvas", &attrs); // Check for failure if (context_handle <= 0) { diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 69b9889c7..f1497f741 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -64,7 +64,7 @@ std::unique_ptr GlTextureBuffer::Create( int actual_ws = image_frame.WidthStep(); int alignment = 0; std::unique_ptr temp; - const uint8* data = image_frame.PixelData(); + const uint8_t* data = image_frame.PixelData(); // Let's see if the pixel data is tightly aligned to one of the alignments // supported by OpenGL, preferring 4 if possible since it's the default. 
diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc
index c7acd1340..4b0913b96 100644
--- a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc
+++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc
@@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
                                                    GpuBufferFormat format) {
   libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
   int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
-  auto y_data = std::make_unique<uint8[]>(y_stride * height);
+  auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
   switch (fourcc) {
     case libyuv::FOURCC_NV12:
     case libyuv::FOURCC_NV21: {
@@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
       int uv_width = 2 * std::ceil(0.5f * width);
       int uv_height = std::ceil(0.5f * height);
       int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
-      auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
+      auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
       yuv_image_ = std::make_shared<YUVImage>(
           fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
           nullptr, 0, width, height);
@@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
       int uv_width = std::ceil(0.5f * width);
       int uv_height = std::ceil(0.5f * height);
       int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
-      auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
-      auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
+      auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
+      auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
       yuv_image_ = std::make_shared<YUVImage>(
           fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
           std::move(v_data), uv_stride, width, height);
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
index 0a7e7a0e0..6aa68a284 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
@@ -16,6 +16,7 @@ import csv
 import filecmp
 import os
 import tempfile
+import unittest
 from unittest import mock as unittest_mock
 
 import tensorflow as tf
@@ -24,6 +25,7 @@
 from mediapipe.model_maker.python.text import text_classifier
 from mediapipe.tasks.python.test import test_utils
 
 
+@unittest.skip('b/275624089')
 class TextClassifierTest(tf.test.TestCase):
   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD
index f3d4407d8..b97d215da 100644
--- a/mediapipe/model_maker/python/vision/object_detector/BUILD
+++ b/mediapipe/model_maker/python/vision/object_detector/BUILD
@@ -175,11 +175,7 @@ py_test(
     data = [":testdata"],
     tags = ["requires-net:external"],
     deps = [
-        ":dataset",
-        ":hyperparameters",
-        ":model_spec",
-        ":object_detector",
-        ":object_detector_options",
+        ":object_detector_import",
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py
index df6b58a07..02f773e69 100644
--- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py
+++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py
@@ -19,11 +19,7 @@
 from unittest import mock as unittest_mock
 
 from absl.testing import parameterized
 import tensorflow as tf
 
-from mediapipe.model_maker.python.vision.object_detector import dataset
-from mediapipe.model_maker.python.vision.object_detector import hyperparameters
-from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
-from mediapipe.model_maker.python.vision.object_detector import object_detector
-from mediapipe.model_maker.python.vision.object_detector import object_detector_options
+from mediapipe.model_maker.python.vision import object_detector
 from mediapipe.tasks.python.test import test_utils as task_test_utils
 
 
@@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     super().setUp()
     dataset_folder = task_test_utils.get_test_data_path('coco_data')
     cache_dir = self.create_tempdir()
-    self.data = dataset.Dataset.from_coco_folder(
+    self.data = object_detector.Dataset.from_coco_folder(
         dataset_folder, cache_dir=cache_dir
     )
     # Mock tempfile.gettempdir() to be unique for each test to avoid race
@@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     self.addCleanup(mock_gettempdir.stop)
 
   def test_object_detector(self):
-    hparams = hyperparameters.HParams(
+    hparams = object_detector.HParams(
         epochs=1,
         batch_size=2,
         learning_rate=0.9,
         shuffle=False,
         export_dir=self.create_tempdir(),
     )
-    options = object_detector_options.ObjectDetectorOptions(
-        supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
+    options = object_detector.ObjectDetectorOptions(
+        supported_model=object_detector.SupportedModels.MOBILENET_V2,
+        hparams=hparams,
     )
     # Test `create``
     model = object_detector.ObjectDetector.create(
@@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     self.assertGreater(os.path.getsize(output_metadata_file), 0)
 
     # Test `quantization_aware_training`
-    qat_hparams = hyperparameters.QATHParams(
+    qat_hparams = object_detector.QATHParams(
         learning_rate=0.9,
         batch_size=2,
         epochs=1,
diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc b/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc
index eebf88579..1685a4f68 100644
--- a/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc
+++ b/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc
@@ -24,8 +24,8 @@ namespace mediapipe {
 
 void FrameAnnotationTracker::AddDetectionResult(
     const FrameAnnotation& frame_annotation) {
-  const int64 time_us =
-      static_cast<int64>(std::round(frame_annotation.timestamp()));
+  const int64_t time_us =
+      static_cast<int64_t>(std::round(frame_annotation.timestamp()));
   for (const auto& object_annotation : frame_annotation.annotations()) {
     detected_objects_[time_us + object_annotation.object_id()] =
         object_annotation;
@@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
     absl::flat_hash_set<int>* cancel_object_ids) {
   CHECK(cancel_object_ids != nullptr);
   FrameAnnotation frame_annotation;
-  std::vector<int64> keys_to_be_deleted;
+  std::vector<int64_t> keys_to_be_deleted;
   for (const auto& detected_obj : detected_objects_) {
     const int object_id = detected_obj.second.object_id();
     if (cancel_object_ids->contains(object_id)) {
diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD
index 7b2b97783..5aa9c9729 100644
--- a/mediapipe/tasks/cc/core/BUILD
+++ b/mediapipe/tasks/cc/core/BUILD
@@ -78,6 +78,7 @@ cc_library(
     hdrs = ["mediapipe_builtin_op_resolver.h"],
     deps = [
         "//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
+
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite", "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index ae64e33ef..80097fd09 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" @@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("KmeansEmbeddingLookup", mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); // For the UniversalSentenceEncoder model. + AddCustom("TFSentencepieceTokenizeOp", + mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER()); AddCustom("RaggedTensorToTensor", mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR()); } diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD index 072b21f53..19f843c4e 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD @@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + filegroup( name = "config_fbs", srcs = ["config.fbs"], @@ -80,3 +87,86 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +cc_library( + name = "sentencepiece_constants", + hdrs = ["sentencepiece_constants.h"], +) + +cc_library( + name = "model_converter", + srcs = [ + "model_converter.cc", + ], + hdrs = [ + "model_converter.h", + ], + deps = [ + ":config", + ":double_array_trie_builder", + ":encoder_config", + ":sentencepiece_constants", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_sentencepiece//src:sentencepiece_model_cc_proto", + ], +) + +cc_library( + name = "optimized_encoder", + srcs = [ + "optimized_encoder.cc", + ], + hdrs = [ + "optimized_encoder.h", + ], + deps = [ + ":double_array_trie", + ":encoder_config", + ":utils", + ], +) + +cc_library( + name = "sentencepiece_tokenizer_tflite", + srcs = ["sentencepiece_tokenizer_tflite.cc"], + hdrs = ["sentencepiece_tokenizer_tflite.h"], + visibility = [ + "//visibility:public", + ], + deps = + [ + ":optimized_encoder", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_test( + name = "optimized_encoder_test", + srcs = [ + "optimized_encoder_test.cc", + ], + data = [ + ":testdata", + ], + deps = [ 
+ ":double_array_trie_builder", + ":encoder_config", + ":model_converter", + ":optimized_encoder", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_sentencepiece//src:sentencepiece_cc_proto", + "@com_google_sentencepiece//src:sentencepiece_processor", + "@org_tensorflow//tensorflow/core:lib", + ], +) diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc new file mode 100644 index 000000000..3a831f3d7 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h" +#include "src/sentencepiece_model.pb.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +std::tuple, std::vector> +DecodePrecompiledCharsmap( + const ::sentencepiece::NormalizerSpec& normalizer_spec) { + // This function "undoes" encoding done by + // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap. + const char* precompiled_map = normalizer_spec.precompiled_charsmap().data(); + const uint32_t trie_size = + *reinterpret_cast(precompiled_map); + const uint32_t* trie_ptr = + reinterpret_cast(precompiled_map + sizeof(uint32_t)); + const int8_t* normalized_ptr = reinterpret_cast( + precompiled_map + sizeof(uint32_t) + trie_size); + const int normalized_size = normalizer_spec.precompiled_charsmap().length() - + sizeof(uint32_t) - trie_size; + return std::make_tuple( + std::vector(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)), + std::vector(normalized_ptr, normalized_ptr + normalized_size)); +} + +absl::StatusOr ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( + "Invalid configuration, can't parse SentencePiece model config " + + model_config.InitializationErrorString()); + } + // Convert sentencepieces. 
+ std::vector pieces; + pieces.reserve(model_config.pieces_size()); + std::vector scores; + scores.reserve(model_config.pieces_size()); + std::vector ids; + ids.reserve(model_config.pieces_size()); + float min_score = 0.0; + int index = 0; + for (const auto& piece : model_config.pieces()) { + switch (piece.type()) { + case ::sentencepiece::ModelProto::SentencePiece::NORMAL: + case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: + pieces.push_back(piece.piece()); + ids.push_back(index); + if (piece.score() < min_score) { + min_score = piece.score(); + } + break; + case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: + case ::sentencepiece::ModelProto::SentencePiece::CONTROL: + // Ignore unknown and control codes. + break; + default: + return absl::InvalidArgumentError("Invalid SentencePiece piece type " + + piece.piece()); + } + scores.push_back(piece.score()); + ++index; + } + flatbuffers::FlatBufferBuilder builder(1024); + const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids)); + const auto pieces_score_vector = builder.CreateVector(scores); + TrieBuilder pieces_trie_builder(builder); + pieces_trie_builder.add_nodes(pieces_trie_vector); + const auto pieces_trie_fbs = pieces_trie_builder.Finish(); + + // Converting normalization. + const auto normalization = + DecodePrecompiledCharsmap(model_config.normalizer_spec()); + const auto normalization_trie = std::get<0>(normalization); + const auto normalization_strings = std::get<1>(normalization); + const auto normalization_trie_vector = + builder.CreateVector(normalization_trie); + TrieBuilder normalization_trie_builder(builder); + normalization_trie_builder.add_nodes(normalization_trie_vector); + const auto normalization_trie_fbs = normalization_trie_builder.Finish(); + const auto normalization_strings_fbs = + builder.CreateVector(normalization_strings); + + EncoderConfigBuilder ecb(builder); + ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); + ecb.add_start_code(model_config.trainer_spec().bos_id()); + ecb.add_end_code(model_config.trainer_spec().eos_id()); + ecb.add_unknown_code(model_config.trainer_spec().unk_id()); + ecb.add_unknown_penalty(min_score - kUnkPenalty); + ecb.add_encoding_offset(encoding_offset); + ecb.add_pieces(pieces_trie_fbs); + ecb.add_pieces_scores(pieces_score_vector); + ecb.add_remove_extra_whitespaces( + model_config.normalizer_spec().remove_extra_whitespaces()); + ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix()); + ecb.add_escape_whitespaces( + model_config.normalizer_spec().escape_whitespaces()); + ecb.add_normalized_prefixes(normalization_trie_fbs); + ecb.add_normalized_replacements(normalization_strings_fbs); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + return std::string(reinterpret_cast(builder.GetBufferPointer()), + builder.GetSize()); +} + +std::string ConvertSentencepieceModel(const std::string& model_string) { + const auto result = ConvertSentencepieceModelToFlatBuffer(model_string); + assert(result.status().ok()); + return result.value(); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h new file mode 100644 index 000000000..828db16da --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ + +#include + +#include "absl/status/statusor.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +// Converts Sentencepiece configuration to flatbuffer format. +// encoding_offset is used by some encoders that combine different encodings. +absl::StatusOr ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset = 0); +std::string ConvertSentencepieceModel(const std::string& model_string); + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc new file mode 100644 index 000000000..365b1a5ad --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc @@ -0,0 +1,236 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" + +#include +#include + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" + +namespace mediapipe::tflite_operations::sentencepiece { +namespace { + +const char kSpaceSymbol[] = "\xe2\x96\x81"; + +template +std::tuple> process_string( + const std::string& input, const std::vector& offsets, + const processing_callback& pc) { + std::string result_string; + result_string.reserve(input.size()); + std::vector result_offsets; + result_offsets.reserve(offsets.size()); + for (int i = 0, j = 0; i < input.size();) { + auto result = pc(input.data() + i, input.size() - i); + auto consumed = std::get<0>(result); + auto new_string = std::get<1>(result); + if (consumed == 0) { + // Skip the current byte and move forward. 
+ result_string.push_back(input[i]); + result_offsets.push_back(offsets[j]); + i++; + j++; + continue; + } + result_string.append(new_string.data(), new_string.length()); + for (int i = 0; i < new_string.length(); ++i) { + result_offsets.push_back(offsets[j]); + } + j += consumed; + i += consumed; + } + return std::make_tuple(result_string, result_offsets); +} + +inline char is_whitespace(char c) { + return c == ' ' || c == '\t' || c == '\r' || c == '\n'; +} + +std::tuple remove_extra_whitespaces(const char* data, + int len) { + if (len == 0 || !is_whitespace(*data)) { + return std::make_tuple(0, utils::string_view(nullptr, 0)); + } + int num_consumed = 1; + for (; num_consumed < len && is_whitespace(data[num_consumed]); + ++num_consumed) { + } + return num_consumed > 1 + ? std::make_tuple(num_consumed, utils::string_view(" ", 1)) + : std::make_tuple(0, utils::string_view(nullptr, 0)); +} + +std::tuple find_replacement( + const char* data, int len, const DoubleArrayTrie& dat, + const flatbuffers::Vector& replacements) { + const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); + if (!max_match.empty()) { + // Because flatbuffer byte is signed char which is not the same as char, + // there is the reinterpret_cast here. + const char* replaced_string_ptr = + reinterpret_cast(replacements.data() + max_match.id); + return std::make_tuple(max_match.match_length, + utils::string_view(replaced_string_ptr)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); +} +} // namespace + +std::tuple> NormalizeString( + const std::string& in_string, const EncoderConfig& config) { + std::vector output_offsets; + std::string result = in_string; + output_offsets.reserve(in_string.length()); + for (int i = 0; i < in_string.length(); ++i) { + output_offsets.push_back(i); + } + if (in_string.empty()) { + return std::make_tuple(result, output_offsets); + } + if (config.add_dummy_prefix()) { + result.insert(result.begin(), ' '); + output_offsets.insert(output_offsets.begin(), 0); + } + // Greedely replace normalized_prefixes with normalized_replacements + if (config.normalized_prefixes() != nullptr && + config.normalized_replacements() != nullptr) { + const DoubleArrayTrie normalized_prefixes_matcher( + config.normalized_prefixes()->nodes()); + const auto norm_replace = [&config, &normalized_prefixes_matcher]( + const char* data, int len) { + return find_replacement(data, len, normalized_prefixes_matcher, + *config.normalized_replacements()); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, norm_replace); + } + if (config.remove_extra_whitespaces()) { + std::tie(result, output_offsets) = + process_string(result, output_offsets, remove_extra_whitespaces); + if (!result.empty() && is_whitespace(result.back())) { + result.pop_back(); + output_offsets.pop_back(); + } + } + if (config.escape_whitespaces()) { + const auto replace_whitespaces = [](const char* data, int len) { + if (len > 0 && is_whitespace(*data)) { + return std::make_tuple(1, utils::string_view(kSpaceSymbol)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, replace_whitespaces); + } + + return std::make_tuple(result, output_offsets); +} + +EncoderResult EncodeNormalizedString(const std::string& str, + const std::vector& offsets, + const EncoderConfig& config, bool add_bos, + bool add_eos, bool reverse) { + const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); + const 
flatbuffers::Vector* piece_scores = config.pieces_scores(); + const int unknown_code = config.unknown_code(); + const float unknown_penalty = config.unknown_penalty(); + struct LatticeElement { + float score = 0; + int code = -1; + int prev_position = -1; + LatticeElement(float score_, int code_, int prev_position_) + : score(score_), code(code_), prev_position(prev_position_) {} + LatticeElement() {} + }; + const int length = str.length(); + std::vector lattice(length + 1); + for (int i = 0; i < length; ++i) { + if (i > 0 && lattice[i].prev_position < 0) { + // This state is unreachable. + continue; + } + if (unknown_code >= 0) { + // Put unknown code. + const float penalized_score = lattice[i].score + unknown_penalty; + const int pos = i + 1; + LatticeElement& current_element = lattice[pos]; + if (current_element.prev_position < 0 || + current_element.score < penalized_score) { + current_element = LatticeElement( + penalized_score, unknown_code, + // If the current state is already reached by unknown code, merge + // states. + lattice[i].code == unknown_code ? lattice[i].prev_position : i); + } + } + auto lattice_update = [&lattice, i, + piece_scores](const DoubleArrayTrie::Match& m) { + LatticeElement& target_element = lattice[i + m.match_length]; + const float score = lattice[i].score + (*piece_scores)[m.id]; + if (target_element.prev_position < 0 || target_element.score < score) { + target_element = LatticeElement(score, m.id, i); + } + }; + piece_matcher.IteratePrefixMatches( + utils::string_view(str.data() + i, length - i), lattice_update); + } + + EncoderResult result; + if (add_eos) { + result.codes.push_back(config.end_code()); + result.offsets.push_back(length); + } + if (lattice[length].prev_position >= 0) { + for (int pos = length; pos > 0;) { + auto code = lattice[pos].code; + if (code != config.unknown_code()) { + code += config.encoding_offset(); + } + result.codes.push_back(code); + pos = lattice[pos].prev_position; + result.offsets.push_back(offsets[pos]); + } + } + if (add_bos) { + result.codes.push_back(config.start_code()); + result.offsets.push_back(0); + } + if (!reverse) { + std::reverse(result.codes.begin(), result.codes.end()); + std::reverse(result.offsets.begin(), result.offsets.end()); + } + return result; +} + +EncoderResult EncodeString(const std::string& string, const void* config_buffer, + bool add_bos, bool add_eos, bool reverse) { + // Get the config from the buffer. + const EncoderConfig* config = GetEncoderConfig(config_buffer); + if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { + EncoderResult result; + result.type = EncoderResultType::WRONG_CONFIG; + return result; + } + std::string normalized_string; + std::vector offsets; + std::tie(normalized_string, offsets) = NormalizeString(string, *config); + return EncodeNormalizedString(normalized_string, offsets, *config, add_bos, + add_eos, reverse); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h new file mode 100644 index 000000000..849a47849 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ + +// Sentencepiece encoder optimized with memmapped model. + +#include +#include +#include + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 }; + +struct EncoderResult { + EncoderResultType type = EncoderResultType::SUCCESS; + std::vector codes; + std::vector offsets; +}; +std::tuple> NormalizeString( + const std::string& in_string, const EncoderConfig& config); + +// Encodes one string and returns ids and offsets. Takes the configuration as a +// type-erased buffer. +EncoderResult EncodeString(const std::string& string, const void* config_buffer, + bool add_bos, bool add_eos, bool reverse); + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc new file mode 100644 index 000000000..e65bd1850 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" +#include "src/sentencepiece.pb.h" +#include "src/sentencepiece_processor.h" +#include "tensorflow/core/platform/env.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +namespace internal { + +tensorflow::Status TFReadFileToString(const std::string& filepath, + std::string* data) { + return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, + data); +} + +absl::Status StdReadFileToString(const std::string& filepath, + std::string* data) { + std::ifstream infile(filepath); + if (!infile.is_open()) { + return absl::NotFoundError( + absl::StrFormat("Error when opening %s", filepath)); + } + std::string contents((std::istreambuf_iterator(infile)), + (std::istreambuf_iterator())); + data->append(contents); + infile.close(); + return absl::OkStatus(); +} +} // namespace internal + +namespace { + +using ::mediapipe::file::JoinPath; + +static char kConfigFilePath[] = + "/mediapipe/tasks/cc/text/custom_ops/" + "sentencepiece/testdata/sentencepiece.model"; + +TEST(OptimizedEncoder, NormalizeStringWhitestpaces) { + flatbuffers::FlatBufferBuilder builder(1024); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_add_dummy_prefix(true); + ecb.add_escape_whitespaces(true); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("x y", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); + } + { + const auto result = NormalizeString("\tx y\n", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); + } +} + +TEST(OptimizedEncoder, NormalizeStringReplacement) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector norm_prefixes = {"A", "AA", "AAA", "AAAA"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4"; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9})); + const auto norm_r = builder.CreateVector( + reinterpret_cast(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(false); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("ABAABAAABAAAA", *config); + const auto res_string = 
std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); + } +} + +TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector norm_prefixes = {"A", "AA", "AAA", "AAAA", + "X"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4\0 "; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12})); + const auto norm_r = builder.CreateVector( + reinterpret_cast(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("XXABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, " A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); + } +} + +TEST(OptimizedEncoder, ConfigConverter) { + std::string config; + auto status = + internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config); + ASSERT_TRUE(status.ok()); + + ::sentencepiece::SentencePieceProcessor processor; + ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); + const auto converted_model = ConvertSentencepieceModel(config); + const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); + const auto encoded = + EncodeString(test_string, converted_model.data(), false, false, false); + ASSERT_EQ(encoded.codes.size(), encoded.offsets.size()); + + ::sentencepiece::SentencePieceText reference_encoded; + ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); + EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size()); + for (int i = 0; i < encoded.codes.size(); ++i) { + EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id()); + EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin()); + } +} + +} // namespace +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h new file mode 100644 index 000000000..faf481844 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ + +namespace mediapipe::tflite_operations::sentencepiece { + +// The constant is copied from +// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc +constexpr float kUnkPenalty = 10.0; + +// These constants are copied from +// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc +// +// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK). +constexpr char kSpaceSymbol[] = "\xe2\x96\x81"; + +// Encodes into U+2047 (DOUBLE QUESTION MARK), +// since this character can be useful both for user and +// developer. We can easily figure out that is emitted. +constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 "; + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc new file mode 100644 index 000000000..468a3a54f --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc @@ -0,0 +1,129 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" + +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace mediapipe::tflite_operations { +namespace sentencepiece::tokenizer { +namespace { + +using ::tflite::SetTensorToDynamic; + +constexpr int kSPModelIndex = 0; +constexpr int kInputIndex = 1; +constexpr int kAddBOSInput = 4; +constexpr int kAddEOSInput = 5; +constexpr int kReverseInput = 6; + +constexpr int kOutputValuesInd = 0; +constexpr int kOutputSplitsInd = 1; + +TfLiteIntArray* CreateSizeArray(const std::initializer_list& sizes) { + TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); + int index = 0; + for (const int size : sizes) { + array_size->data[index++] = size; + } + return array_size; +} +} // namespace + +// Initializes text encoder object from serialized parameters. 
+void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO: Add checks for input and output tensors. + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + SetTensorToDynamic(&output_values); + + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + SetTensorToDynamic(&output_splits); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor& model_tensor = + context->tensors[node->inputs->data[kSPModelIndex]]; + const auto model_buffer_data = model_tensor.data.data; + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputIndex]]; + + const TfLiteTensor add_bos_tensor = + context->tensors[node->inputs->data[kAddBOSInput]]; + const bool add_bos = add_bos_tensor.data.b[0]; + const TfLiteTensor add_eos_tensor = + context->tensors[node->inputs->data[kAddEOSInput]]; + const bool add_eos = add_eos_tensor.data.b[0]; + const TfLiteTensor reverse_tensor = + context->tensors[node->inputs->data[kReverseInput]]; + const bool reverse = reverse_tensor.data.b[0]; + + std::vector encoded; + std::vector splits; + const int num_strings = tflite::GetStringCount(&input_text); + for (int i = 0; i < num_strings; ++i) { + const auto strref = tflite::GetString(&input_text, i); + const auto res = EncodeString(std::string(strref.str, strref.len), + model_buffer_data, add_bos, add_eos, reverse); + TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS, + "Sentencepiece conversion failed"); + std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded)); + splits.emplace_back(encoded.size()); + } + + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor( + context, &output_values, + CreateSizeArray({static_cast(encoded.size())}))); + int32_t* output_values_flat = output_values.data.i32; + std::copy(encoded.begin(), encoded.end(), output_values_flat); + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_splits, + CreateSizeArray({static_cast(splits.size() + 1)}))); + int32_t* output_splits_flat = output_splits.data.i32; + *output_splits_flat = 0; + std::copy(splits.begin(), splits.end(), output_splits_flat + 1); + return kTfLiteOk; +} +} // namespace sentencepiece::tokenizer + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() { + static TfLiteRegistration r = { + sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free, + sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval}; + return &r; +} + +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h new file mode 100644 index 000000000..8a9fa8aef --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace mediapipe::tflite_operations { + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); + +} // namespace mediapipe::tflite_operations + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model new file mode 100644 index 000000000..041188ffd Binary files /dev/null and b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model differ diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index 5e0be5578..474f0ca35 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite"; // Embedding model with regex preprocessing. constexpr char kRegexOneEmbeddingModel[] = "regex_one_embedding_with_metadata.tflite"; +constexpr char kUniversalSentenceEncoderModel[] = + "universal_sentence_encoder_qa_with_metadata.tflite"; // Tolerance for embedding vector coordinate values. constexpr float kEpsilon = 1e-4; @@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { MP_ASSERT_OK(text_embedder->Close()); } +TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) { + auto options = std::make_unique<TextEmbedderOptions>(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + auto result0, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result0.embeddings.size(), 1); + ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100); + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + ASSERT_EQ(result1.embeddings.size(), 1); + ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100); + ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon); + + // Check cosine similarity.
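The check that follows uses standard cosine similarity between the two 100-dimensional embeddings, dot(a, b) / (||a|| * ||b||). A minimal, self-contained sketch of that computation (the task API returns the value wrapped in a StatusOr rather than a raw double):

#include <cmath>
#include <cstddef>
#include <vector>

// Cosine similarity of two equal-length float vectors. Illustrative helper
// only; it returns 0 for zero-norm inputs instead of reporting an error.
double CosineSimilarity(const std::vector<float>& u, const std::vector<float>& v) {
  double dot = 0.0, norm_u = 0.0, norm_v = 0.0;
  for (std::size_t i = 0; i < u.size(); ++i) {
    dot += u[i] * v[i];
    norm_u += u[i] * u[i];
    norm_v += v[i] * v[i];
  }
  if (norm_u == 0.0 || norm_v == 0.0) return 0.0;
  return dot / (std::sqrt(norm_u) * std::sqrt(norm_v));
}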
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { auto options = std::make_unique(); options->base_options.model_asset_path = @@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { MP_ASSERT_OK(text_embedder->Close()); } +TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("When you go to this restaurant, they hold the " + "pancake upside-down before they hand it " + "to you. It's a great gimmick.")); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result1, + text_embedder->Embed( + "Let's make a plan to steal the declaration of independence.")); + + // Check cosine similarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + } // namespace } // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/vision/face_stylizer/BUILD b/mediapipe/tasks/cc/vision/face_stylizer/BUILD index bdbf340b8..27b2f482d 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/BUILD +++ b/mediapipe/tasks/cc/vision/face_stylizer/BUILD @@ -23,18 +23,12 @@ cc_library( srcs = ["face_stylizer_graph.cc"], deps = [ "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/image:image_cropping_calculator", - "//mediapipe/calculators/image:image_cropping_calculator_cc_proto", - "//mediapipe/calculators/image:warp_affine_calculator", - "//mediapipe/calculators/image:warp_affine_calculator_cc_proto", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:face_to_rect_calculator", - "//mediapipe/calculators/util:from_image_calculator", - "//mediapipe/calculators/util:inverse_matrix_calculator", "//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto", - "//mediapipe/calculators/util:to_image_calculator", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", @@ -53,7 +47,6 @@ cc_library( "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h 
b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h index 14c23b7a8..36bb11bd7 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h @@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // The input image can be of any size with format RGB or RGBA. // When no face is detected on the input image, the method returns a // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. To ensure that the output image has reasonable quality, the stylized - // output image size is the smaller of the model output size and the size of - // the 'region_of_interest' specified in 'image_processing_options'. + // face. The stylized output image size is the same as the model output size. absl::StatusOr> Stylize( mediapipe::Image image, std::optional image_processing_options = @@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. // When no face is detected on the input image, the method returns a // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. To ensure that the output image has reasonable quality, the stylized - // output image size is the smaller of the model output size and the size of - // the 'region_of_interest' specified in 'image_processing_options'. + // face. The stylized output image size is the same as the model output size. absl::StatusOr> StylizeForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = @@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // The "result_callback" provides: // - When no face is detected on the input image, the method returns a // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. To ensure that the output image has reasonable quality, the - // stylized output image size is the smaller of the model output size and - // the size of the 'region_of_interest' specified in - // 'image_processing_options'. + // face. The stylized output image size is the same as the model output + // size. // - The input timestamp in milliseconds. absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms, std::optional diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc index bf717a71d..27b8dacc1 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc @@ -19,8 +19,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "mediapipe/calculators/core/split_vector_calculator.pb.h" -#include "mediapipe/calculators/image/image_cropping_calculator.pb.h" -#include "mediapipe/calculators/image/warp_affine_calculator.pb.h" +#include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" @@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph { image_in >> preprocessing.In(kImageTag); face_rect >> preprocessing.In(kNormRectTag); auto preprocessed_tensors = preprocessing.Out(kTensorsTag); - auto transform_matrix = preprocessing.Out(kMatrixTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. @@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph { model_output_tensors >> tensors_to_image.In(kTensorsTag); auto tensor_image = tensors_to_image.Out(kImageTag); - auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); - transform_matrix >> inverse_matrix.In(kMatrixTag); - auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag); + auto& image_converter = graph.AddNode("ImageCloneCalculator"); + image_converter.GetOptions() + .set_output_on_gpu(false); + tensor_image >> image_converter.In(""); - auto& warp_affine = graph.AddNode("WarpAffineCalculator"); - auto& warp_affine_options = - warp_affine.GetOptions(); - warp_affine_options.set_border_mode( - WarpAffineCalculatorOptions::BORDER_ZERO); - warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT); - tensor_image >> warp_affine.In(kImageTag); - inverse_transform_matrix >> warp_affine.In(kMatrixTag); - image_size >> warp_affine.In(kOutputSizeTag); - auto image_to_crop = warp_affine.Out(kImageTag); - - // The following calculators are for cropping and resizing the output image - // based on the roi and the model output size. As the WarpAffineCalculator - // rotates the image based on the transform matrix, the rotation info in the - // rect proto is stripped to prevent the ImageCroppingCalculator from - // performing extra rotation. - auto& strip_rotation = - graph.AddNode("mediapipe.tasks.StripRotationCalculator"); - face_rect >> strip_rotation.In(kNormRectTag); - auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag); - auto& from_image = graph.AddNode("FromImageCalculator"); - image_to_crop >> from_image.In(kImageTag); - auto& image_cropping = graph.AddNode("ImageCroppingCalculator"); - auto& image_cropping_opts = - image_cropping.GetOptions(); - image_cropping_opts.set_output_max_width( - image_to_tensor_options.output_tensor_width()); - image_cropping_opts.set_output_max_height( - image_to_tensor_options.output_tensor_height()); - norm_rect_no_rotation >> image_cropping.In(kNormRectTag); - auto& to_image = graph.AddNode("ToImageCalculator"); - // ImageCroppingCalculator currently doesn't support mediapipe::Image, the - // graph selects its cpu or gpu path based on the image preprocessing - // backend. 
- if (use_gpu) { - from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag); - image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag); - } else { - from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag); - image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag); - } - - return {{/*stylized_image=*/to_image.Out(kImageTag).Cast(), + return {{/*stylized_image=*/image_converter.Out("").Cast(), /*original_image=*/preprocessing.Out(kImageTag).Cast()}}; } }; diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index 07cf793e9..19f546257 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -100,6 +100,7 @@ cc_library( "//mediapipe/util:graph_builder_utils", "@com_google_absl//absl/status:statusor", ], + alwayslink = 1, ) cc_library( diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index cd464c6a1..bbc9aa8a5 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -90,7 +90,7 @@ NS_SWIFT_NAME(ClassificationResult) * amount of data to process might exceed the maximum size that the model can process: to solve * this, the input data is split into multiple chunks starting at different timestamps. */ -@property(nonatomic, readonly) NSInteger timestampMs; +@property(nonatomic, readonly) NSInteger timestampInMilliseconds; /** * Initializes a new `MPPClassificationResult` with the given array of classifications and time @@ -98,14 +98,15 @@ NS_SWIFT_NAME(ClassificationResult) * * @param classifications An Array of `MPPClassifications` objects containing the predicted * categories for each head of the model. - * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data + * @param timestampInMilliseconds The timestamp (in milliseconds) of the start of the chunk of data * corresponding to these results. * * @return An instance of `MPPClassificationResult` initialized with the given array of - * classifications and timestampMs. + * classifications and timestamp (in milliseconds). 
*/ - (instancetype)initWithClassifications:(NSArray *)classifications - timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + NS_DESIGNATED_INITIALIZER; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index 6d42d22ca..8d9440492 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -38,11 +38,11 @@ @implementation MPPClassificationResult - (instancetype)initWithClassifications:(NSArray *)classifications - timestampMs:(NSInteger)timestampMs { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { self = [super init]; if (self) { _classifications = classifications; - _timestampMs = timestampMs; + _timestampInMilliseconds = timestampInMilliseconds; } return self; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h index 8fd9b9dff..4cfd8890d 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h @@ -33,7 +33,7 @@ NS_SWIFT_NAME(EmbeddingResult) * cases, the amount of data to process might exceed the maximum size that the model can process. To * solve this, the input data is split into multiple chunks starting at different timestamps. */ -@property(nonatomic, readonly) NSInteger timestampMs; +@property(nonatomic, readonly) NSInteger timestampInMilliseconds; /** * Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in @@ -41,14 +41,14 @@ NS_SWIFT_NAME(EmbeddingResult) * * @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each * head of the model. - * @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data - * corresponding to these results. Pass `0` if timestamp is absent. + * @param timestampInMilliseconds The optional timestamp (in milliseconds) of the start of the chunk + * of data corresponding to these results. Pass `0` if timestamp is absent. * * @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and - * timestampMs. + * timestamp (in milliseconds). 
*/ - (instancetype)initWithEmbeddings:(NSArray *)embeddings - timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m index 56dd30fdd..1f4828583 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m @@ -17,11 +17,11 @@ @implementation MPPEmbeddingResult - (instancetype)initWithEmbeddings:(NSArray *)embeddings - timestampMs:(NSInteger)timestampMs { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { self = [super init]; if (self) { _embeddings = embeddings; - _timestampMs = timestampMs; + _timestampInMilliseconds = timestampInMilliseconds; } return self; diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm index b02b032bb..47f1cf45c 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -55,13 +55,13 @@ using ClassificationResultProto = [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]]; } - NSInteger timestampMs = 0; + NSInteger timestampInMilliseconds = 0; if (classificationResultProto.has_timestamp_ms()) { - timestampMs = (NSInteger)classificationResultProto.timestamp_ms(); + timestampInMilliseconds = (NSInteger)classificationResultProto.timestamp_ms(); } return [[MPPClassificationResult alloc] initWithClassifications:classifications - timestampMs:timestampMs]; + timestampInMilliseconds:timestampInMilliseconds]; ; } diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm index f9863e9ca..cf5569c07 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm @@ -31,12 +31,13 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto:: [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]]; } - NSInteger timestampMs = 0; + NSInteger timestampInMilliseconds = 0; if (embeddingResultProto.has_timestamp_ms()) { - timestampMs = (NSInteger)embeddingResultProto.timestamp_ms(); + timestampInMilliseconds = (NSInteger)embeddingResultProto.timestamp_ms(); } - return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs]; + return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings + timestampInMilliseconds:timestampInMilliseconds]; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index 4ee7b2fc6..664a94ba6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -26,11 +26,12 @@ NS_SWIFT_NAME(TaskResult) /** * Timestamp that is associated with the task result object. 
*/ -@property(nonatomic, assign, readonly) NSInteger timestampMs; +@property(nonatomic, assign, readonly) NSInteger timestampInMilliseconds; - (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds + NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 6c08014ff..8a7fa6b5b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -16,16 +16,16 @@ @implementation MPPTaskResult -- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { +- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds { self = [super init]; if (self) { - _timestampMs = timestampMs; + _timestampInMilliseconds = timestampInMilliseconds; } return self; } - (id)copyWithZone:(NSZone *)zone { - return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; + return [[MPPTaskResult alloc] initWithTimestampInMilliseconds:self.timestampInMilliseconds]; } @end diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m index 613239944..d3a027b6c 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m +++ b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m @@ -487,7 +487,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; NSError *liveStreamApiCallError; XCTAssertFalse([imageClassifier classifyAsyncImage:image - timestampMs:0 + timestampInMilliseconds:0 error:&liveStreamApiCallError]); NSError *expectedLiveStreamApiCallError = @@ -501,7 +501,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError); NSError *videoApiCallError; - XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]); + XCTAssertFalse([imageClassifier classifyVideoFrame:image + timestampInMilliseconds:0 + error:&videoApiCallError]); NSError *expectedVideoApiCallError = [NSError errorWithDomain:kExpectedErrorDomain @@ -524,7 +526,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; NSError *liveStreamApiCallError; XCTAssertFalse([imageClassifier classifyAsyncImage:image - timestampMs:0 + timestampInMilliseconds:0 error:&liveStreamApiCallError]); NSError *expectedLiveStreamApiCallError = @@ -575,7 +577,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; AssertEqualErrors(imageApiCallError, expectedImageApiCallError); NSError *videoApiCallError; - XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]); + XCTAssertFalse([imageClassifier classifyVideoFrame:image + timestampInMilliseconds:0 + error:&videoApiCallError]); NSError *expectedVideoApiCallError = [NSError errorWithDomain:kExpectedErrorDomain @@ -601,7 +605,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; for (int i = 0; i < 3; i++) { MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image - timestampMs:i + timestampInMilliseconds:i error:nil]; [self assertImageClassifierResult:imageClassifierResult hasExpectedCategoriesCount:maxResults @@ -630,10 +634,10 @@ static NSString 
*const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; MPPImage *image = [self imageWithFileInfo:kBurgerImage]; - XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:1 error:nil]); + XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]); NSError *error; - XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampMs:0 error:&error]); + XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampInMilliseconds:0 error:&error]); NSError *expectedError = [NSError errorWithDomain:kExpectedErrorDomain @@ -668,7 +672,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; MPPImage *image = [self imageWithFileInfo:kBurgerImage]; for (int i = 0; i < 3; i++) { - XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]); + XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]); } } diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h index 6744a8e16..9ce7fcec2 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextClassifierResult) * * @param classificationResult The `MPPClassificationResult` instance containing one set of results * per classifier head. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPTextClassifierResult` initialized with the given * `MPPClassificationResult` and timestamp (in milliseconds). */ - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m index 4d5c1104a..09a2097cc 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -17,8 +17,8 @@ @implementation MPPTextClassifierResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _classificationResult = classificationResult; } diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm index f5d6aa1d3..5a924016e 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -35,7 +35,7 @@ using ::mediapipe::Packet; return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult - timestampMs:(NSInteger)(packet.Timestamp().Value() / + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; } diff --git 
a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h index e4697dcef..ab8edd16b 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h @@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextEmbedderResult) * * @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results * per classifier head. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in millisecondss) for this result. * * @return An instance of `MPPTextEmbedderResult` initialized with the given * `MPPEmbeddingResult` and timestamp (in milliseconds). */ - (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m index 5483e3c3f..d764f63d6 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m @@ -17,8 +17,8 @@ @implementation MPPTextEmbedderResult - (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _embeddingResult = embeddingResult; } diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm index b769292ce..3534ea66d 100644 --- a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm @@ -34,7 +34,7 @@ using ::mediapipe::Packet; return [[MPPTextEmbedderResult alloc] initWithEmbeddingResult:embeddingResult - timestampMs:(NSInteger)(packet.Timestamp().Value() / + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; } diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h index eaf059ad2..ed07c6d90 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h @@ -41,7 +41,7 @@ * timestamp. * * @param image The image to send to the MediaPipe graph. - * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet. * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no * error will be saved. * @@ -49,7 +49,7 @@ * occurred during the conversion. */ + (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error; /** @@ -66,11 +66,11 @@ * specified timestamp. * * @param image The `NormalizedRect` to send to the MediaPipe graph. 
- * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet. * * @return The MediaPipe packet containing the normalized rect. */ + (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm index bf136a759..af419c6d0 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm @@ -42,7 +42,7 @@ using ::mediapipe::Timestamp; } + (Packet)createPacketWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error { std::unique_ptr imageFrame = [image imageFrameWithError:error]; @@ -51,7 +51,7 @@ using ::mediapipe::Timestamp; } return MakePacket(std::move(imageFrame)) - .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); + .At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond))); } + (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect { @@ -59,9 +59,9 @@ using ::mediapipe::Timestamp; } + (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect - timestampMs:(NSInteger)timestampMs { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { return MakePacket(std::move(normalizedRect)) - .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); + .At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond))); } @end diff --git a/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m b/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m index 1fa1a9d37..3ffb15392 100644 --- a/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m +++ b/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m @@ -21,7 +21,7 @@ handedness:(NSArray *> *)handedness gestures:(NSArray *> *)gestures timestampInMilliseconds:(NSInteger)timestampInMilliseconds { - self = [super initWithTimestampMs:timestampInMilliseconds]; + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _landmarks = landmarks; _worldLandmarks = worldLandmarks; diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h index 581c8d95b..345687877 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h @@ -122,17 +122,17 @@ NS_SWIFT_NAME(ImageClassifier) * `MPPRunningModeVideo`. * * @param image The `MPPImage` on which image classification is to be performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing image * classification on the input video frame. 
* * @return An `MPPImageClassifierResult` object that contains a list of image classifications. */ - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error - NS_SWIFT_NAME(classify(videoFrame:timestampMs:)); + NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:)); /** * Performs image classification on the provided video frame of type `MPPImage` cropped to the @@ -145,8 +145,8 @@ NS_SWIFT_NAME(ImageClassifier) * * @param image A live stream image data of type `MPPImage` on which image classification is to be * performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the video frame of type * `MPPImage`, on which image classification should be performed. * @param error An optional error parameter populated when there is an error in performing image @@ -155,10 +155,10 @@ NS_SWIFT_NAME(ImageClassifier) * @return An `MPPImageClassifierResult` object that contains a list of image classifications. */ - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error - NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:)); + NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:regionOfInterest:)); /** * Sends live stream image data of type `MPPImage` to perform image classification using the whole @@ -172,16 +172,17 @@ NS_SWIFT_NAME(ImageClassifier) * * @param image A live stream image data of type `MPPImage` on which image classification is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the image classifier. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the image classifier. The input timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing image * classification on the input live stream image data. * * @return `YES` if the image was sent to the task successfully, otherwise `NO`. */ - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:)); /** * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the @@ -195,8 +196,8 @@ NS_SWIFT_NAME(ImageClassifier) * * @param image A live stream image data of type `MPPImage` on which image classification is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the image classifier. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the image classifier. 
The input timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the given live stream image data * of type `MPPImage`, on which image classification should be performed. * @param error An optional error parameter populated when there is an error in performing image @@ -205,10 +206,10 @@ NS_SWIFT_NAME(ImageClassifier) * @return `YES` if the image was sent to the task successfully, otherwise `NO`. */ - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error - NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:regionOfInterest:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm index 8051fbf3d..18c1bb56a 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm @@ -149,7 +149,7 @@ static NSString *const kTaskGraphName = } - (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional rect = @@ -162,14 +162,15 @@ static NSString *const kTaskGraphName = } Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds error:error]; if (imagePacket.IsEmpty()) { return std::nullopt; } - Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() - timestampMs:timestampMs]; + Packet normalizedRectPacket = + [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampInMilliseconds:timestampInMilliseconds]; PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); return inputPacketMap; @@ -180,11 +181,11 @@ static NSString *const kTaskGraphName = } - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -204,20 +205,20 @@ static NSString *const kTaskGraphName = } - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error { return [self classifyVideoFrame:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + 
timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -228,10 +229,10 @@ static NSString *const kTaskGraphName = } - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error { return [self classifyAsyncImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h index 92fdb13cb..478bd452a 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h @@ -31,13 +31,13 @@ NS_SWIFT_NAME(ImageClassifierResult) * * @param classificationResult The `MPPClassificationResult` instance containing one set of results * per classifier head. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPImageClassifierResult` initialized with the given * `MPPClassificationResult` and timestamp (in milliseconds). */ - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m index 6dcd064eb..cb17bb10e 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m @@ -17,8 +17,8 @@ @implementation MPPImageClassifierResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _classificationResult = classificationResult; } diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm index 09e21b278..f5199765d 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm @@ -34,7 +34,7 @@ using ::mediapipe::Packet; return [[MPPImageClassifierResult alloc] initWithClassificationResult:classificationResult - timestampMs:(NSInteger)(packet.Timestamp().Value() / + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; } diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h index 590867bf8..da9899d40 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h @@ -36,13 +36,13 @@ NS_SWIFT_NAME(ObjectDetectionResult) * 
@param detections An array of `MPPDetection` objects each of which has a bounding box that is * expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width) * x [0,image_height)`, which are the dimensions of the underlying image data. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections * and timestamp (in milliseconds). */ - (instancetype)initWithDetections:(NSArray *)detections - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m index ac24c19fa..47902bba4 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m @@ -17,8 +17,8 @@ @implementation MPPObjectDetectionResult - (instancetype)initWithDetections:(NSArray *)detections - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _detections = detections; } diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h index 58344d0c7..f92c90c50 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h @@ -138,8 +138,8 @@ NS_SWIFT_NAME(ObjectDetector) * `MPPRunningModeVideo`. * * @param image The `MPPImage` on which object detection is to be performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing object * detection on the input image. * @@ -149,9 +149,9 @@ NS_SWIFT_NAME(ObjectDetector) * image data. */ - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error - NS_SWIFT_NAME(detect(videoFrame:timestampMs:)); + NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:)); /** * Performs object detection on the provided video frame of type `MPPImage` cropped to the @@ -164,8 +164,8 @@ NS_SWIFT_NAME(ObjectDetector) * * @param image A live stream image data of type `MPPImage` on which object detection is to be * performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which * object detection should be performed. * @@ -178,10 +178,10 @@ NS_SWIFT_NAME(ObjectDetector) * image data. 
*/ - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error - NS_SWIFT_NAME(detect(videoFrame:timestampMs:regionOfInterest:)); + NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:regionOfInterest:)); /** * Sends live stream image data of type `MPPImage` to perform object detection using the whole @@ -195,16 +195,17 @@ NS_SWIFT_NAME(ObjectDetector) * * @param image A live stream image data of type `MPPImage` on which object detection is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the object detector. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the object detector. The input timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing object * detection on the input live stream image data. * * @return `YES` if the image was sent to the task successfully, otherwise `NO`. */ - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error NS_SWIFT_NAME(detectAsync(image:timestampMs:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:)); /** * Sends live stream image data of type `MPPImage` to perform object detection, cropped to the @@ -218,8 +219,8 @@ NS_SWIFT_NAME(ObjectDetector) * * @param image A live stream image data of type `MPPImage` on which object detection is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the object detector. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the object detector. The input timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the given live stream image data * of type `MPPImage`, on which iobject detection should be performed. * @param error An optional error parameter populated when there is an error in performing object @@ -228,10 +229,10 @@ NS_SWIFT_NAME(ObjectDetector) * @return `YES` if the image was sent to the task successfully, otherwise `NO`. 
*/ - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error - NS_SWIFT_NAME(detectAsync(image:timestampMs:regionOfInterest:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:regionOfInterest:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm index 53dcad4a8..e1aa11e96 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm @@ -157,7 +157,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional rect = @@ -170,14 +170,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds error:error]; if (imagePacket.IsEmpty()) { return std::nullopt; } - Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() - timestampMs:timestampMs]; + Packet normalizedRectPacket = + [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampInMilliseconds:timestampInMilliseconds]; PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); return inputPacketMap; @@ -188,11 +189,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -212,20 +213,20 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error { return [self detectInVideoFrame:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -236,10 +237,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error { + 
timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error { return [self detectAsyncInImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } diff --git a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm index 3507b7d72..225a6993d 100644 --- a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm +++ b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm @@ -38,8 +38,9 @@ using ::mediapipe::Packet; } return [[MPPObjectDetectionResult alloc] - initWithDetections:detections - timestampMs:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; + initWithDetections:detections + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; } @end diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java index b95e9021f..d6f565c78 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java @@ -198,9 +198,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created @@ -220,9 +220,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the @@ -256,9 +256,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a @@ -281,9 +281,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the @@ -320,9 +320,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). @@ -346,9 +346,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the @@ -387,9 +387,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). @@ -414,9 +414,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). @@ -445,9 +445,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { * *

    {@link FaceStylizer} supports the following color space types: * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output * size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * *

      *
    • {@link android.graphics.Bitmap.Config#ARGB_8888} @@ -475,9 +475,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *
    • {@link android.graphics.Bitmap.Config#ARGB_8888} *
    * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index b38bd1c86..6f47797b4 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { "IMAGE:" + IMAGE_IN_STREAM_NAME, "ROI:" + ROI_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); - private static final List OUTPUT_STREAMS = - Collections.unmodifiableList( - Arrays.asList( - "GROUPED_SEGMENTATION:segmented_mask_out", - "IMAGE:image_out", - "SEGMENTATION:0:segmentation")); - private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; - private static final int IMAGE_OUT_STREAM_INDEX = 1; - private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; + private static final int IMAGE_OUT_STREAM_INDEX = 0; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = @@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { */ public static InteractiveSegmenter createFromOptions( Context context, InteractiveSegmenterOptions segmenterOptions) { + if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) { + throw new IllegalArgumentException( + "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set."); + } + List outputStreams = new ArrayList<>(); + outputStreams.add("IMAGE:image_out"); + if (segmenterOptions.outputConfidenceMasks()) { + outputStreams.add("CONFIDENCE_MASKS:confidence_masks"); + } + final int confidenceMasksOutStreamIndex = outputStreams.size() - 1; + if (segmenterOptions.outputCategoryMask()) { + outputStreams.add("CATEGORY_MASK:category_mask"); + } + final int categoryMaskOutStreamIndex = outputStreams.size() - 1; + // TODO: Consolidate OutputHandler and TaskRunner. 
OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( @@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { @Override public ImageSegmenterResult convertToTaskResult(List packets) throws MediaPipeException { - if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { + if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( Optional.empty(), Optional.empty(), - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); + packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp()); } - List segmentedMasks = new ArrayList<>(); - int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); - int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); - int imageFormat = - segmenterOptions.outputType() - == InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK - ? MPImage.IMAGE_FORMAT_VEC32F1 - : MPImage.IMAGE_FORMAT_ALPHA; - int imageListSize = - PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); - ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; - // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe - // graph. If provided, the result MPImage is wrapping the mediapipe packet memory. - if (!segmenterOptions.resultListener().isPresent()) { - for (int i = 0; i < imageListSize; i++) { - buffersArray[i] = - ByteBuffer.allocateDirect( - width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); + // If resultListener is not provided, the resulted MPImage is deep copied from + // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet + // memory. + boolean copyImage = !segmenterOptions.resultListener().isPresent(); + Optional> confidenceMasks = Optional.empty(); + if (segmenterOptions.outputConfidenceMasks()) { + confidenceMasks = Optional.of(new ArrayList<>()); + int width = + PacketGetter.getImageWidthFromImageList( + packets.get(confidenceMasksOutStreamIndex)); + int height = + PacketGetter.getImageHeightFromImageList( + packets.get(confidenceMasksOutStreamIndex)); + int imageListSize = + PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + // confidence masks are float type image. + final int numBytes = 4; + if (copyImage) { + for (int i = 0; i < imageListSize; i++) { + buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes); + } + } + if (!PacketGetter.getImageList( + packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting confidence masks."); + } + for (ByteBuffer buffer : buffersArray) { + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); + confidenceMasks.get().add(builder.build()); } } - if (!PacketGetter.getImageList( - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), - buffersArray, - !segmenterOptions.resultListener().isPresent())) { - throw new MediaPipeException( - MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting segmented masks. 
It usually results from incorrect" - + " options of unsupported OutputType of given model."); - } - for (ByteBuffer buffer : buffersArray) { + Optional categoryMask = Optional.empty(); + if (segmenterOptions.outputCategoryMask()) { + int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex)); + int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex)); + ByteBuffer buffer; + if (copyImage) { + buffer = ByteBuffer.allocateDirect(width * height); + if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting category mask."); + } + } else { + buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex)); + } ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, imageFormat); - segmentedMasks.add(builder.build()); + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); + categoryMask = Optional.of(builder.build()); } return ImageSegmenterResult.create( - Optional.of(segmentedMasks), - Optional.empty(), + confidenceMasks, + categoryMask, BaseVisionTaskApi.generateResultTimestampMs( - RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); + RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX))); } @Override @@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { .setTaskRunningModeName(RunningMode.IMAGE.name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) - .setOutputStreams(OUTPUT_STREAMS) + .setOutputStreams(outputStreams) .setTaskOptions(segmenterOptions) .setEnableFlowLimiting(false) .build(), @@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { /** Sets the base options for the image segmenter task. */ public abstract Builder setBaseOptions(BaseOptions value); - /** The output type from image segmenter. */ - public abstract Builder setOutputType(OutputType value); + /** Sets whether to output confidence masks. Default to true. */ + public abstract Builder setOutputConfidenceMasks(boolean value); + + /** Sets whether to output category mask. Default to false. */ + public abstract Builder setOutputCategoryMask(boolean value); /** * Sets an optional {@link ResultListener} to receive the segmentation results when the graph @@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { abstract BaseOptions baseOptions(); - abstract OutputType outputType(); + abstract boolean outputConfidenceMasks(); + + abstract boolean outputCategoryMask(); abstract Optional> resultListener(); abstract Optional errorListener(); - /** The output type of segmentation results. */ - public enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. 
- CONFIDENCE_MASK - } - public static Builder builder() { return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder() - .setOutputType(OutputType.CATEGORY_MASK); + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false); } /** @@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.newBuilder(); - if (outputType() == OutputType.CONFIDENCE_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); - } else if (outputType() == OutputType.CATEGORY_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); - } - taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java index 5b880f419..4f6cc2d68 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java @@ -234,8 +234,8 @@ public class FaceStylizerTest { FaceStylizerResult actualResult = faceStylizer.stylize(inputImage); MPImage stylizedImage = actualResult.stylizedImage().get(); assertThat(stylizedImage).isNotNull(); - assertThat(stylizedImage.getWidth()).isEqualTo(83); - assertThat(stylizedImage.getHeight()).isEqualTo(83); + assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize); + assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); } @Test diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index f32ab7976..3a6854949 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -53,18 +53,15 @@ public class InteractiveSegmenterTest { InteractiveSegmenterOptions options = InteractiveSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK) + .setOutputConfidenceMasks(false) + .setOutputCategoryMask(true) .build(); InteractiveSegmenter imageSegmenter = InteractiveSegmenter.createFromOptions( ApplicationProvider.getApplicationContext(), options); MPImage image = getImageFromAsset(inputImageName); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); - // TODO update to correct category mask output. - // After InteractiveSegmenter updated according to (b/276519300), update this to use - // categoryMask field instead of confidenceMasks. 
- List segmentations = actualResult.confidenceMasks().get(); - assertThat(segmentations.size()).isEqualTo(1); + assertThat(actualResult.categoryMask().isPresent()).isTrue(); } @Test @@ -75,15 +72,17 @@ public class InteractiveSegmenterTest { InteractiveSegmenterOptions options = InteractiveSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false) .build(); InteractiveSegmenter imageSegmenter = InteractiveSegmenter.createFromOptions( ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName), roi); - List segmentations = actualResult.confidenceMasks().get(); - assertThat(segmentations.size()).isEqualTo(2); + assertThat(actualResult.confidenceMasks().isPresent()).isTrue(); + List confidenceMasks = actualResult.confidenceMasks().get(); + assertThat(confidenceMasks.size()).isEqualTo(2); } } diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index aa48a1a9a..250c0fa62 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -204,6 +204,11 @@ This can be useful for resetting a stateful task graph to process new data. Raises: RuntimeError: The underlying medipaipe graph fails to reset and restart. )doc"); + + task_runner.def( + "get_graph_config", + [](TaskRunner* self) { return self->GetGraphConfig(); }, + R"doc(Returns the canonicalized CalculatorGraphConfig of the underlying graph.)doc"); } } // namespace python diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 78e98a1b4..62d162f6e 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions _BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' _REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' +_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' # Tolerance for embedding vector coordinate values. 
_EPSILON = 1e-4 @@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase): 16, (0.549632, 0.552879), ), + ( + False, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.851961, + 100, + (1.422951, 1.404664), + ), + ( + True, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.851961, + 100, + (0.127049, 0.125416), + ), ) def test_embed(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, expected_first_values): @@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase): 16, (0.549632, 0.552879), ), + ( + False, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.851961, + 100, + (1.422951, 1.404664), + ), + ( + True, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.851961, + 100, + (0.127049, 0.125416), + ), ) def test_embed_in_context(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, @@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase): @parameterized.parameters( # TODO: The similarity should likely be lower (_BERT_MODEL_FILE, 0.980880), + (_USE_MODEL_FILE, 0.780334), ) def test_embed_with_different_themes(self, model_file, expected_similarity): # Creates embedder. diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 7f0b47eb7..c11cea865 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -15,7 +15,6 @@ import enum import os -from typing import List from unittest import mock from absl.testing import absltest @@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision.core import vision_task_running_mode +ImageSegmenterResult = image_segmenter.ImageSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat -_OutputType = image_segmenter.ImageSegmenterOptions.OutputType -_Activation = image_segmenter.ImageSegmenterOptions.Activation _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode @@ -42,11 +40,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _MODEL_FILE = 'deeplabv3.tflite' _IMAGE_FILE = 'segmentation_input_rotation0.jpg' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' +_CAT_IMAGE = 'cat.jpg' +_CAT_MASK = 'cat_mask.jpg' _MASK_MAGNIFICATION_FACTOR = 10 _MASK_SIMILARITY_THRESHOLD = 0.98 _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' +def _calculate_soft_iou(m1, m2): + intersection_sum = np.sum(m1 * m2) + union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum + + if union_sum > 0: + return intersection_sum / union_sum + else: + return 0 + + +def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold): + actual_mask = actual_mask.numpy_view() + expected_mask = expected_mask.numpy_view() / 255.0 + + return ( + actual_mask.shape == expected_mask.shape + and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold + ) + + def _similar_to_uint8_mask(actual_mask, expected_mask): actual_mask_pixels = actual_mask.numpy_view().flatten() expected_mask_pixels = expected_mask.numpy_view().flatten() @@ -56,8 +76,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask): for index in range(num_pixels): 
consistent_pixels += ( - actual_mask_pixels[index] * - _MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index]) + actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR + == expected_mask_pixels[index] + ) return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD @@ -73,16 +94,27 @@ class ImageSegmenterTest(parameterized.TestCase): super().setUp() # Load the test input image. self.test_image = _Image.create_from_file( - test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))) + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)) + ) # Loads ground truth segmentation file. gt_segmentation_data = cv2.imread( test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)), - cv2.IMREAD_GRAYSCALE) + os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE) + ), + cv2.IMREAD_GRAYSCALE, + ) self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data) self.model_path = test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) + os.path.join(_TEST_DATA_DIR, _MODEL_FILE) + ) + + def _load_segmentation_mask(self, file_path: str): + # Loads ground truth segmentation file. + gt_segmentation_data = cv2.imread( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)), + cv2.IMREAD_GRAYSCALE, + ) + return _Image(_ImageFormat.GRAY8, gt_segmentation_data) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. @@ -98,9 +130,11 @@ class ImageSegmenterTest(parameterized.TestCase): def test_create_from_options_fails_with_invalid_model_path(self): with self.assertRaisesRegex( - RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): base_options = _BaseOptions( - model_asset_path='/path/to/invalid/model.tflite') + model_asset_path='/path/to/invalid/model.tflite' + ) options = _ImageSegmenterOptions(base_options=base_options) _ImageSegmenter.create_from_options(options) @@ -112,8 +146,9 @@ class ImageSegmenterTest(parameterized.TestCase): segmenter = _ImageSegmenter.create_from_options(options) self.assertIsInstance(segmenter, _ImageSegmenter) - @parameterized.parameters((ModelFileType.FILE_NAME,), - (ModelFileType.FILE_CONTENT,)) + @parameterized.parameters( + (ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,) + ) def test_segment_succeeds_with_category_mask(self, model_file_type): # Creates segmenter. if model_file_type is ModelFileType.FILE_NAME: @@ -127,22 +162,27 @@ class ImageSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) + base_options=base_options, + output_category_mask=True, + output_confidence_masks=False, + ) segmenter = _ImageSegmenter.create_from_options(options) # Performs image segmentation on the input. - category_masks = segmenter.segment(self.test_image) - self.assertLen(category_masks, 1) - category_mask = category_masks[0] + segmentation_result = segmenter.segment(self.test_image) + category_mask = segmentation_result.category_mask result_pixels = category_mask.numpy_view().flatten() # Check if data type of `category_mask` is correct. 
self.assertEqual(result_pixels.dtype, np.uint8) self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_uint8_mask(category_mask, self.test_seg_image), + ( + 'Number of pixels in the candidate mask differing from that of the' + f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.' + ), + ) # Closes the segmenter explicitly when the segmenter is not used in # a context. @@ -152,74 +192,46 @@ class ImageSegmenterTest(parameterized.TestCase): # Creates segmenter. base_options = _BaseOptions(model_asset_path=self.model_path) - # Run segmentation on the model in CATEGORY_MASK mode. - options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) - segmenter = _ImageSegmenter.create_from_options(options) - category_masks = segmenter.segment(self.test_image) - category_mask = category_masks[0].numpy_view() + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) # Run segmentation on the model in CONFIDENCE_MASK mode. options = _ImageSegmenterOptions( base_options=base_options, - output_type=_OutputType.CONFIDENCE_MASK, - activation=_Activation.SOFTMAX) - segmenter = _ImageSegmenter.create_from_options(options) - confidence_masks = segmenter.segment(self.test_image) + output_category_mask=False, + output_confidence_masks=True, + ) - # Check if confidence mask shape is correct. - self.assertLen( - confidence_masks, 21, - 'Number of confidence masks must match with number of categories.') - - # Gather the confidence masks in a single array `confidence_mask_array`. - confidence_mask_array = np.array( - [confidence_mask.numpy_view() for confidence_mask in confidence_masks]) - - # Check if data type of `confidence_masks` are correct. - self.assertEqual(confidence_mask_array.dtype, np.float32) - - # Compute the category mask from the created confidence mask. - calculated_category_mask = np.argmax(confidence_mask_array, axis=0) - self.assertListEqual( - calculated_category_mask.tolist(), category_mask.tolist(), - 'Confidence mask does not match with the category mask.') - - # Closes the segmenter explicitly when the segmenter is not used in - # a context. - segmenter.close() - - @parameterized.parameters((ModelFileType.FILE_NAME), - (ModelFileType.FILE_CONTENT)) - def test_segment_in_context(self, model_file_type): - if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(model_asset_path=self.model_path) - elif model_file_type is ModelFileType.FILE_CONTENT: - with open(self.model_path, 'rb') as f: - model_contents = f.read() - base_options = _BaseOptions(model_asset_buffer=model_contents) - else: - # Should never happen - raise ValueError('model_file_type is invalid.') - - options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) with _ImageSegmenter.create_from_options(options) as segmenter: - # Performs image segmentation on the input. - category_masks = segmenter.segment(self.test_image) - self.assertLen(category_masks, 1) + segmentation_result = segmenter.segment(test_image) + confidence_masks = segmentation_result.confidence_masks + + # Check if confidence mask shape is correct. 
+ self.assertLen( + confidence_masks, + 21, + 'Number of confidence masks must match with number of categories.', + ) + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(_CAT_MASK) self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) def test_missing_result_callback(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM) - with self.assertRaisesRegex(ValueError, - r'result callback must be provided'): + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): with _ImageSegmenter.create_from_options(options) as unused_segmenter: pass @@ -228,130 +240,236 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=running_mode, - result_callback=mock.MagicMock()) - with self.assertRaisesRegex(ValueError, - r'result callback should not be provided'): + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): with _ImageSegmenter.create_from_options(options) as unused_segmenter: pass def test_calling_segment_for_video_in_image_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): segmenter.segment_for_video(self.test_image, 0) def test_calling_segment_async_in_image_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): segmenter.segment_async(self.test_image, 0) def test_calling_segment_in_video_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): segmenter.segment(self.test_image) def test_calling_segment_async_in_video_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): 
segmenter.segment_async(self.test_image, 0) def test_segment_for_video_with_out_of_order_timestamp(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: unused_result = segmenter.segment_for_video(self.test_image, 1) with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): segmenter.segment_for_video(self.test_image, 0) - def test_segment_for_video(self): + def test_segment_for_video_in_category_mask_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - output_type=_OutputType.CATEGORY_MASK, - running_mode=_RUNNING_MODE.VIDEO) + output_category_mask=True, + output_confidence_masks=False, + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): - category_masks = segmenter.segment_for_video(self.test_image, timestamp) - self.assertLen(category_masks, 1) + segmentation_result = segmenter.segment_for_video( + self.test_image, timestamp + ) + category_mask = segmentation_result.category_mask self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_uint8_mask(category_mask, self.test_seg_image), + ( + 'Number of pixels in the candidate mask differing from that of' + f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.' + ), + ) + + def test_segment_for_video_in_confidence_mask_mode(self): + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) + + options = _ImageSegmenterOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + output_category_mask=False, + output_confidence_masks=True, + ) + with _ImageSegmenter.create_from_options(options) as segmenter: + for timestamp in range(0, 300, 30): + segmentation_result = segmenter.segment_for_video(test_image, timestamp) + confidence_masks = segmentation_result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, + 21, + 'Number of confidence masks must match with number of categories.', + ) + + # Loads ground truth segmentation file. 
+ expected_mask = self._load_segmentation_mask(_CAT_MASK) + self.assertTrue( + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) def test_calling_segment_in_live_stream_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): segmenter.segment(self.test_image) def test_calling_segment_for_video_in_live_stream_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): segmenter.segment_for_video(self.test_image, 0) def test_segment_async_calls_with_illegal_timestamp(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageSegmenter.create_from_options(options) as segmenter: segmenter.segment_async(self.test_image, 100) with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): segmenter.segment_async(self.test_image, 0) - def test_segment_async_calls(self): + def test_segment_async_calls_in_category_mask_mode(self): observed_timestamp_ms = -1 - def check_result(result: List[image_module.Image], output_image: _Image, - timestamp_ms: int): + def check_result( + result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int + ): # Get the output category mask. - category_mask = result[0] + category_mask = result.category_mask self.assertEqual(output_image.width, self.test_image.width) self.assertEqual(output_image.height, self.test_image.height) self.assertEqual(output_image.width, self.test_seg_image.width) self.assertEqual(output_image.height, self.test_seg_image.height) self.assertTrue( _similar_to_uint8_mask(category_mask, self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + ( + 'Number of pixels in the candidate mask differing from that of' + f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.' + ), + ) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - output_type=_OutputType.CATEGORY_MASK, + output_category_mask=True, + output_confidence_masks=False, running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result) + result_callback=check_result, + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(self.test_image, timestamp) + def test_segment_async_calls_in_confidence_mask_mode(self): + # Load the cat image. 
+ test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(_CAT_MASK) + observed_timestamp_ms = -1 + + def check_result( + result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int + ): + # Get the output category mask. + confidence_masks = result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, + 21, + 'Number of confidence masks must match with number of categories.', + ) + self.assertEqual(output_image.width, test_image.width) + self.assertEqual(output_image.height, test_image.height) + self.assertTrue( + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + options = _ImageSegmenterOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + output_category_mask=False, + output_confidence_masks=True, + result_callback=check_result, + ) + with _ImageSegmenter.create_from_options(options) as segmenter: + for timestamp in range(0, 300, 30): + segmenter.segment_async(test_image, timestamp) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py index e8c52ae3e..2e0039b15 100644 --- a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py @@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import interactive_segmenter from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat _NormalizedKeypoint = keypoint_module.NormalizedKeypoint _Rect = rect.Rect -_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _RegionOfInterest = interactive_segmenter.RegionOfInterest @@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK + base_options=base_options, + output_category_mask=True, + output_confidence_masks=False, ) segmenter = _InteractiveSegmenter.create_from_options(options) # Performs image segmentation on the input. roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) - category_masks = segmenter.segment(self.test_image, roi) - self.assertLen(category_masks, 1) - category_mask = category_masks[0] + segmentation_result = segmenter.segment(self.test_image, roi) + category_mask = segmentation_result.category_mask result_pixels = category_mask.numpy_view().flatten() # Check if data type of `category_mask` is correct. 
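The Python hunks above make the same switch as the Java changes earlier in this diff: the old OutputType enum is gone, and mask selection is expressed with the two booleans output_category_mask and output_confidence_masks on the options object. As a rough usage sketch (not part of this change), assuming a segmentation model at the placeholder path 'model.tflite', an input image at the placeholder path 'input.jpg', and the keypoint ROI format exercised by these tests:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import interactive_segmenter

# Request only the category mask; confidence masks are not requested.
options = interactive_segmenter.InteractiveSegmenterOptions(
    base_options=base_options_module.BaseOptions(model_asset_path='model.tflite'),
    output_category_mask=True,
    output_confidence_masks=False,
)

# A keypoint ROI marking the object of interest, in normalized image coordinates.
roi = interactive_segmenter.RegionOfInterest(
    format=interactive_segmenter.RegionOfInterest.Format.KEYPOINT,
    keypoint=keypoint_module.NormalizedKeypoint(x=0.5, y=0.5),
)

with interactive_segmenter.InteractiveSegmenter.create_from_options(options) as segmenter:
    image = image_module.Image.create_from_file('input.jpg')
    result = segmenter.segment(image, roi)
    # Only the category mask was requested above, so read it from the result:
    # a uint8 image with one class index per pixel.
    mask = result.category_mask.numpy_view()

As on the Java side, at least one of the two mask outputs has to be requested; here only the category mask is, so the graph never produces confidence masks.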
@@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): self.assertTrue( _similar_to_uint8_mask( - category_masks[0], test_seg_image, similarity_threshold + category_mask, test_seg_image, similarity_threshold ), ( 'Number of pixels in the candidate mask differing from that of the' @@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK + base_options=base_options, + output_category_mask=False, + output_confidence_masks=True, ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation - confidence_masks = segmenter.segment(self.test_image, roi) + segmentation_result = segmenter.segment(self.test_image, roi) + confidence_masks = segmentation_result.confidence_masks # Check if confidence mask shape is correct. self.assertLen( @@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK + base_options=base_options, + output_category_mask=False, + output_confidence_masks=True, ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) - confidence_masks = segmenter.segment( + segmentation_result = segmenter.segment( self.test_image, roi, image_processing_options ) + confidence_masks = segmentation_result.confidence_masks # Check if confidence mask shape is correct. self.assertLen( @@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK + base_options=base_options, + output_category_mask=False, + output_confidence_masks=True, ) with self.assertRaisesRegex( diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py index 49fe03059..53cbf026e 100644 --- a/mediapipe/tasks/python/vision/__init__.py +++ b/mediapipe/tasks/python/vision/__init__.py @@ -32,6 +32,7 @@ FaceDetectorResult = face_detector.FaceDetectorResult FaceLandmarker = face_landmarker.FaceLandmarker FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult +FaceLandmarksConnections = face_landmarker.FaceLandmarksConnections FaceStylizer = face_stylizer.FaceStylizer FaceStylizerOptions = face_stylizer.FaceStylizerOptions GestureRecognizer = gesture_recognizer.GestureRecognizer diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index eb976153e..d9195d3ce 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -208,6 +208,11 @@ class BaseVisionTaskApi(object): """ self._runner.close() + def get_graph_config(self) -> calculator_pb2.CalculatorGraphConfig: + """Returns the canonicalized CalculatorGraphConfig of the underlying graph. 
+ """ + return self._runner.get_graph_config() + def __enter__(self): """Return `self` upon entering the runtime context.""" return self diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py index 3e43e8a7f..c5b24499f 100644 --- a/mediapipe/tasks/python/vision/face_landmarker.py +++ b/mediapipe/tasks/python/vision/face_landmarker.py @@ -120,6 +120,2741 @@ class Blendshapes(enum.IntEnum): NOSE_SNEER_RIGHT = 51 +class FaceLandmarksConnections: + """The connections between face landmarks.""" + + @dataclasses.dataclass + class Connection: + """The connection class for face landmarks.""" + + start: int + end: int + + FACE_LANDMARKS_LIPS: List[Connection] = [ + Connection(61, 146), + Connection(146, 91), + Connection(91, 181), + Connection(181, 84), + Connection(84, 17), + Connection(17, 314), + Connection(314, 405), + Connection(405, 321), + Connection(321, 375), + Connection(375, 291), + Connection(61, 185), + Connection(185, 40), + Connection(40, 39), + Connection(39, 37), + Connection(37, 0), + Connection(0, 267), + Connection(267, 269), + Connection(269, 270), + Connection(270, 409), + Connection(409, 291), + Connection(78, 95), + Connection(95, 88), + Connection(88, 178), + Connection(178, 87), + Connection(87, 14), + Connection(14, 317), + Connection(317, 402), + Connection(402, 318), + Connection(318, 324), + Connection(324, 308), + Connection(78, 191), + Connection(191, 80), + Connection(80, 81), + Connection(81, 82), + Connection(82, 13), + Connection(13, 312), + Connection(312, 311), + Connection(311, 310), + Connection(310, 415), + Connection(415, 308), + ] + + FACE_LANDMARKS_LEFT_EYE: List[Connection] = [ + Connection(263, 249), + Connection(249, 390), + Connection(390, 373), + Connection(373, 374), + Connection(374, 380), + Connection(380, 381), + Connection(381, 382), + Connection(382, 362), + Connection(263, 466), + Connection(466, 388), + Connection(388, 387), + Connection(387, 386), + Connection(386, 385), + Connection(385, 384), + Connection(384, 398), + Connection(398, 362), + ] + + FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [ + Connection(276, 283), + Connection(283, 282), + Connection(282, 295), + Connection(295, 285), + Connection(300, 293), + Connection(293, 334), + Connection(334, 296), + Connection(296, 336), + ] + + FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [ + Connection(474, 475), + Connection(475, 476), + Connection(476, 477), + Connection(477, 474), + ] + + FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [ + Connection(33, 7), + Connection(7, 163), + Connection(163, 144), + Connection(144, 145), + Connection(145, 153), + Connection(153, 154), + Connection(154, 155), + Connection(155, 133), + Connection(33, 246), + Connection(246, 161), + Connection(161, 160), + Connection(160, 159), + Connection(159, 158), + Connection(158, 157), + Connection(157, 173), + Connection(173, 133), + ] + + FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [ + Connection(46, 53), + Connection(53, 52), + Connection(52, 65), + Connection(65, 55), + Connection(70, 63), + Connection(63, 105), + Connection(105, 66), + Connection(66, 107), + ] + + FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [ + Connection(469, 470), + Connection(470, 471), + Connection(471, 472), + Connection(472, 469), + ] + + FACE_LANDMARKS_FACE_OVAL: List[Connection] = [ + Connection(10, 338), + Connection(338, 297), + Connection(297, 332), + Connection(332, 284), + Connection(284, 251), + Connection(251, 389), + Connection(389, 356), + Connection(356, 454), + 
Connection(454, 323), + Connection(323, 361), + Connection(361, 288), + Connection(288, 397), + Connection(397, 365), + Connection(365, 379), + Connection(379, 378), + Connection(378, 400), + Connection(400, 377), + Connection(377, 152), + Connection(152, 148), + Connection(148, 176), + Connection(176, 149), + Connection(149, 150), + Connection(150, 136), + Connection(136, 172), + Connection(172, 58), + Connection(58, 132), + Connection(132, 93), + Connection(93, 234), + Connection(234, 127), + Connection(127, 162), + Connection(162, 21), + Connection(21, 54), + Connection(54, 103), + Connection(103, 67), + Connection(67, 109), + Connection(109, 10), + ] + + FACE_LANDMARKS_CONTOURS: List[Connection] = ( + FACE_LANDMARKS_LIPS + + FACE_LANDMARKS_LEFT_EYE + + FACE_LANDMARKS_LEFT_EYEBROW + + FACE_LANDMARKS_RIGHT_EYE + + FACE_LANDMARKS_RIGHT_EYEBROW + + FACE_LANDMARKS_FACE_OVAL + ) + + FACE_LANDMARKS_TESSELATION: List[Connection] = [ + Connection(127, 34), + Connection(34, 139), + Connection(139, 127), + Connection(11, 0), + Connection(0, 37), + Connection(37, 11), + Connection(232, 231), + Connection(231, 120), + Connection(120, 232), + Connection(72, 37), + Connection(37, 39), + Connection(39, 72), + Connection(128, 121), + Connection(121, 47), + Connection(47, 128), + Connection(232, 121), + Connection(121, 128), + Connection(128, 232), + Connection(104, 69), + Connection(69, 67), + Connection(67, 104), + Connection(175, 171), + Connection(171, 148), + Connection(148, 175), + Connection(118, 50), + Connection(50, 101), + Connection(101, 118), + Connection(73, 39), + Connection(39, 40), + Connection(40, 73), + Connection(9, 151), + Connection(151, 108), + Connection(108, 9), + Connection(48, 115), + Connection(115, 131), + Connection(131, 48), + Connection(194, 204), + Connection(204, 211), + Connection(211, 194), + Connection(74, 40), + Connection(40, 185), + Connection(185, 74), + Connection(80, 42), + Connection(42, 183), + Connection(183, 80), + Connection(40, 92), + Connection(92, 186), + Connection(186, 40), + Connection(230, 229), + Connection(229, 118), + Connection(118, 230), + Connection(202, 212), + Connection(212, 214), + Connection(214, 202), + Connection(83, 18), + Connection(18, 17), + Connection(17, 83), + Connection(76, 61), + Connection(61, 146), + Connection(146, 76), + Connection(160, 29), + Connection(29, 30), + Connection(30, 160), + Connection(56, 157), + Connection(157, 173), + Connection(173, 56), + Connection(106, 204), + Connection(204, 194), + Connection(194, 106), + Connection(135, 214), + Connection(214, 192), + Connection(192, 135), + Connection(203, 165), + Connection(165, 98), + Connection(98, 203), + Connection(21, 71), + Connection(71, 68), + Connection(68, 21), + Connection(51, 45), + Connection(45, 4), + Connection(4, 51), + Connection(144, 24), + Connection(24, 23), + Connection(23, 144), + Connection(77, 146), + Connection(146, 91), + Connection(91, 77), + Connection(205, 50), + Connection(50, 187), + Connection(187, 205), + Connection(201, 200), + Connection(200, 18), + Connection(18, 201), + Connection(91, 106), + Connection(106, 182), + Connection(182, 91), + Connection(90, 91), + Connection(91, 181), + Connection(181, 90), + Connection(85, 84), + Connection(84, 17), + Connection(17, 85), + Connection(206, 203), + Connection(203, 36), + Connection(36, 206), + Connection(148, 171), + Connection(171, 140), + Connection(140, 148), + Connection(92, 40), + Connection(40, 39), + Connection(39, 92), + Connection(193, 189), + Connection(189, 244), + 
Connection(244, 193), + Connection(159, 158), + Connection(158, 28), + Connection(28, 159), + Connection(247, 246), + Connection(246, 161), + Connection(161, 247), + Connection(236, 3), + Connection(3, 196), + Connection(196, 236), + Connection(54, 68), + Connection(68, 104), + Connection(104, 54), + Connection(193, 168), + Connection(168, 8), + Connection(8, 193), + Connection(117, 228), + Connection(228, 31), + Connection(31, 117), + Connection(189, 193), + Connection(193, 55), + Connection(55, 189), + Connection(98, 97), + Connection(97, 99), + Connection(99, 98), + Connection(126, 47), + Connection(47, 100), + Connection(100, 126), + Connection(166, 79), + Connection(79, 218), + Connection(218, 166), + Connection(155, 154), + Connection(154, 26), + Connection(26, 155), + Connection(209, 49), + Connection(49, 131), + Connection(131, 209), + Connection(135, 136), + Connection(136, 150), + Connection(150, 135), + Connection(47, 126), + Connection(126, 217), + Connection(217, 47), + Connection(223, 52), + Connection(52, 53), + Connection(53, 223), + Connection(45, 51), + Connection(51, 134), + Connection(134, 45), + Connection(211, 170), + Connection(170, 140), + Connection(140, 211), + Connection(67, 69), + Connection(69, 108), + Connection(108, 67), + Connection(43, 106), + Connection(106, 91), + Connection(91, 43), + Connection(230, 119), + Connection(119, 120), + Connection(120, 230), + Connection(226, 130), + Connection(130, 247), + Connection(247, 226), + Connection(63, 53), + Connection(53, 52), + Connection(52, 63), + Connection(238, 20), + Connection(20, 242), + Connection(242, 238), + Connection(46, 70), + Connection(70, 156), + Connection(156, 46), + Connection(78, 62), + Connection(62, 96), + Connection(96, 78), + Connection(46, 53), + Connection(53, 63), + Connection(63, 46), + Connection(143, 34), + Connection(34, 227), + Connection(227, 143), + Connection(123, 117), + Connection(117, 111), + Connection(111, 123), + Connection(44, 125), + Connection(125, 19), + Connection(19, 44), + Connection(236, 134), + Connection(134, 51), + Connection(51, 236), + Connection(216, 206), + Connection(206, 205), + Connection(205, 216), + Connection(154, 153), + Connection(153, 22), + Connection(22, 154), + Connection(39, 37), + Connection(37, 167), + Connection(167, 39), + Connection(200, 201), + Connection(201, 208), + Connection(208, 200), + Connection(36, 142), + Connection(142, 100), + Connection(100, 36), + Connection(57, 212), + Connection(212, 202), + Connection(202, 57), + Connection(20, 60), + Connection(60, 99), + Connection(99, 20), + Connection(28, 158), + Connection(158, 157), + Connection(157, 28), + Connection(35, 226), + Connection(226, 113), + Connection(113, 35), + Connection(160, 159), + Connection(159, 27), + Connection(27, 160), + Connection(204, 202), + Connection(202, 210), + Connection(210, 204), + Connection(113, 225), + Connection(225, 46), + Connection(46, 113), + Connection(43, 202), + Connection(202, 204), + Connection(204, 43), + Connection(62, 76), + Connection(76, 77), + Connection(77, 62), + Connection(137, 123), + Connection(123, 116), + Connection(116, 137), + Connection(41, 38), + Connection(38, 72), + Connection(72, 41), + Connection(203, 129), + Connection(129, 142), + Connection(142, 203), + Connection(64, 98), + Connection(98, 240), + Connection(240, 64), + Connection(49, 102), + Connection(102, 64), + Connection(64, 49), + Connection(41, 73), + Connection(73, 74), + Connection(74, 41), + Connection(212, 216), + Connection(216, 207), + Connection(207, 
212), + Connection(42, 74), + Connection(74, 184), + Connection(184, 42), + Connection(169, 170), + Connection(170, 211), + Connection(211, 169), + Connection(170, 149), + Connection(149, 176), + Connection(176, 170), + Connection(105, 66), + Connection(66, 69), + Connection(69, 105), + Connection(122, 6), + Connection(6, 168), + Connection(168, 122), + Connection(123, 147), + Connection(147, 187), + Connection(187, 123), + Connection(96, 77), + Connection(77, 90), + Connection(90, 96), + Connection(65, 55), + Connection(55, 107), + Connection(107, 65), + Connection(89, 90), + Connection(90, 180), + Connection(180, 89), + Connection(101, 100), + Connection(100, 120), + Connection(120, 101), + Connection(63, 105), + Connection(105, 104), + Connection(104, 63), + Connection(93, 137), + Connection(137, 227), + Connection(227, 93), + Connection(15, 86), + Connection(86, 85), + Connection(85, 15), + Connection(129, 102), + Connection(102, 49), + Connection(49, 129), + Connection(14, 87), + Connection(87, 86), + Connection(86, 14), + Connection(55, 8), + Connection(8, 9), + Connection(9, 55), + Connection(100, 47), + Connection(47, 121), + Connection(121, 100), + Connection(145, 23), + Connection(23, 22), + Connection(22, 145), + Connection(88, 89), + Connection(89, 179), + Connection(179, 88), + Connection(6, 122), + Connection(122, 196), + Connection(196, 6), + Connection(88, 95), + Connection(95, 96), + Connection(96, 88), + Connection(138, 172), + Connection(172, 136), + Connection(136, 138), + Connection(215, 58), + Connection(58, 172), + Connection(172, 215), + Connection(115, 48), + Connection(48, 219), + Connection(219, 115), + Connection(42, 80), + Connection(80, 81), + Connection(81, 42), + Connection(195, 3), + Connection(3, 51), + Connection(51, 195), + Connection(43, 146), + Connection(146, 61), + Connection(61, 43), + Connection(171, 175), + Connection(175, 199), + Connection(199, 171), + Connection(81, 82), + Connection(82, 38), + Connection(38, 81), + Connection(53, 46), + Connection(46, 225), + Connection(225, 53), + Connection(144, 163), + Connection(163, 110), + Connection(110, 144), + Connection(52, 65), + Connection(65, 66), + Connection(66, 52), + Connection(229, 228), + Connection(228, 117), + Connection(117, 229), + Connection(34, 127), + Connection(127, 234), + Connection(234, 34), + Connection(107, 108), + Connection(108, 69), + Connection(69, 107), + Connection(109, 108), + Connection(108, 151), + Connection(151, 109), + Connection(48, 64), + Connection(64, 235), + Connection(235, 48), + Connection(62, 78), + Connection(78, 191), + Connection(191, 62), + Connection(129, 209), + Connection(209, 126), + Connection(126, 129), + Connection(111, 35), + Connection(35, 143), + Connection(143, 111), + Connection(117, 123), + Connection(123, 50), + Connection(50, 117), + Connection(222, 65), + Connection(65, 52), + Connection(52, 222), + Connection(19, 125), + Connection(125, 141), + Connection(141, 19), + Connection(221, 55), + Connection(55, 65), + Connection(65, 221), + Connection(3, 195), + Connection(195, 197), + Connection(197, 3), + Connection(25, 7), + Connection(7, 33), + Connection(33, 25), + Connection(220, 237), + Connection(237, 44), + Connection(44, 220), + Connection(70, 71), + Connection(71, 139), + Connection(139, 70), + Connection(122, 193), + Connection(193, 245), + Connection(245, 122), + Connection(247, 130), + Connection(130, 33), + Connection(33, 247), + Connection(71, 21), + Connection(21, 162), + Connection(162, 71), + Connection(170, 169), + 
Connection(169, 150), + Connection(150, 170), + Connection(188, 174), + Connection(174, 196), + Connection(196, 188), + Connection(216, 186), + Connection(186, 92), + Connection(92, 216), + Connection(2, 97), + Connection(97, 167), + Connection(167, 2), + Connection(141, 125), + Connection(125, 241), + Connection(241, 141), + Connection(164, 167), + Connection(167, 37), + Connection(37, 164), + Connection(72, 38), + Connection(38, 12), + Connection(12, 72), + Connection(38, 82), + Connection(82, 13), + Connection(13, 38), + Connection(63, 68), + Connection(68, 71), + Connection(71, 63), + Connection(226, 35), + Connection(35, 111), + Connection(111, 226), + Connection(101, 50), + Connection(50, 205), + Connection(205, 101), + Connection(206, 92), + Connection(92, 165), + Connection(165, 206), + Connection(209, 198), + Connection(198, 217), + Connection(217, 209), + Connection(165, 167), + Connection(167, 97), + Connection(97, 165), + Connection(220, 115), + Connection(115, 218), + Connection(218, 220), + Connection(133, 112), + Connection(112, 243), + Connection(243, 133), + Connection(239, 238), + Connection(238, 241), + Connection(241, 239), + Connection(214, 135), + Connection(135, 169), + Connection(169, 214), + Connection(190, 173), + Connection(173, 133), + Connection(133, 190), + Connection(171, 208), + Connection(208, 32), + Connection(32, 171), + Connection(125, 44), + Connection(44, 237), + Connection(237, 125), + Connection(86, 87), + Connection(87, 178), + Connection(178, 86), + Connection(85, 86), + Connection(86, 179), + Connection(179, 85), + Connection(84, 85), + Connection(85, 180), + Connection(180, 84), + Connection(83, 84), + Connection(84, 181), + Connection(181, 83), + Connection(201, 83), + Connection(83, 182), + Connection(182, 201), + Connection(137, 93), + Connection(93, 132), + Connection(132, 137), + Connection(76, 62), + Connection(62, 183), + Connection(183, 76), + Connection(61, 76), + Connection(76, 184), + Connection(184, 61), + Connection(57, 61), + Connection(61, 185), + Connection(185, 57), + Connection(212, 57), + Connection(57, 186), + Connection(186, 212), + Connection(214, 207), + Connection(207, 187), + Connection(187, 214), + Connection(34, 143), + Connection(143, 156), + Connection(156, 34), + Connection(79, 239), + Connection(239, 237), + Connection(237, 79), + Connection(123, 137), + Connection(137, 177), + Connection(177, 123), + Connection(44, 1), + Connection(1, 4), + Connection(4, 44), + Connection(201, 194), + Connection(194, 32), + Connection(32, 201), + Connection(64, 102), + Connection(102, 129), + Connection(129, 64), + Connection(213, 215), + Connection(215, 138), + Connection(138, 213), + Connection(59, 166), + Connection(166, 219), + Connection(219, 59), + Connection(242, 99), + Connection(99, 97), + Connection(97, 242), + Connection(2, 94), + Connection(94, 141), + Connection(141, 2), + Connection(75, 59), + Connection(59, 235), + Connection(235, 75), + Connection(24, 110), + Connection(110, 228), + Connection(228, 24), + Connection(25, 130), + Connection(130, 226), + Connection(226, 25), + Connection(23, 24), + Connection(24, 229), + Connection(229, 23), + Connection(22, 23), + Connection(23, 230), + Connection(230, 22), + Connection(26, 22), + Connection(22, 231), + Connection(231, 26), + Connection(112, 26), + Connection(26, 232), + Connection(232, 112), + Connection(189, 190), + Connection(190, 243), + Connection(243, 189), + Connection(221, 56), + Connection(56, 190), + Connection(190, 221), + Connection(28, 56), + 
Connection(56, 221), + Connection(221, 28), + Connection(27, 28), + Connection(28, 222), + Connection(222, 27), + Connection(29, 27), + Connection(27, 223), + Connection(223, 29), + Connection(30, 29), + Connection(29, 224), + Connection(224, 30), + Connection(247, 30), + Connection(30, 225), + Connection(225, 247), + Connection(238, 79), + Connection(79, 20), + Connection(20, 238), + Connection(166, 59), + Connection(59, 75), + Connection(75, 166), + Connection(60, 75), + Connection(75, 240), + Connection(240, 60), + Connection(147, 177), + Connection(177, 215), + Connection(215, 147), + Connection(20, 79), + Connection(79, 166), + Connection(166, 20), + Connection(187, 147), + Connection(147, 213), + Connection(213, 187), + Connection(112, 233), + Connection(233, 244), + Connection(244, 112), + Connection(233, 128), + Connection(128, 245), + Connection(245, 233), + Connection(128, 114), + Connection(114, 188), + Connection(188, 128), + Connection(114, 217), + Connection(217, 174), + Connection(174, 114), + Connection(131, 115), + Connection(115, 220), + Connection(220, 131), + Connection(217, 198), + Connection(198, 236), + Connection(236, 217), + Connection(198, 131), + Connection(131, 134), + Connection(134, 198), + Connection(177, 132), + Connection(132, 58), + Connection(58, 177), + Connection(143, 35), + Connection(35, 124), + Connection(124, 143), + Connection(110, 163), + Connection(163, 7), + Connection(7, 110), + Connection(228, 110), + Connection(110, 25), + Connection(25, 228), + Connection(356, 389), + Connection(389, 368), + Connection(368, 356), + Connection(11, 302), + Connection(302, 267), + Connection(267, 11), + Connection(452, 350), + Connection(350, 349), + Connection(349, 452), + Connection(302, 303), + Connection(303, 269), + Connection(269, 302), + Connection(357, 343), + Connection(343, 277), + Connection(277, 357), + Connection(452, 453), + Connection(453, 357), + Connection(357, 452), + Connection(333, 332), + Connection(332, 297), + Connection(297, 333), + Connection(175, 152), + Connection(152, 377), + Connection(377, 175), + Connection(347, 348), + Connection(348, 330), + Connection(330, 347), + Connection(303, 304), + Connection(304, 270), + Connection(270, 303), + Connection(9, 336), + Connection(336, 337), + Connection(337, 9), + Connection(278, 279), + Connection(279, 360), + Connection(360, 278), + Connection(418, 262), + Connection(262, 431), + Connection(431, 418), + Connection(304, 408), + Connection(408, 409), + Connection(409, 304), + Connection(310, 415), + Connection(415, 407), + Connection(407, 310), + Connection(270, 409), + Connection(409, 410), + Connection(410, 270), + Connection(450, 348), + Connection(348, 347), + Connection(347, 450), + Connection(422, 430), + Connection(430, 434), + Connection(434, 422), + Connection(313, 314), + Connection(314, 17), + Connection(17, 313), + Connection(306, 307), + Connection(307, 375), + Connection(375, 306), + Connection(387, 388), + Connection(388, 260), + Connection(260, 387), + Connection(286, 414), + Connection(414, 398), + Connection(398, 286), + Connection(335, 406), + Connection(406, 418), + Connection(418, 335), + Connection(364, 367), + Connection(367, 416), + Connection(416, 364), + Connection(423, 358), + Connection(358, 327), + Connection(327, 423), + Connection(251, 284), + Connection(284, 298), + Connection(298, 251), + Connection(281, 5), + Connection(5, 4), + Connection(4, 281), + Connection(373, 374), + Connection(374, 253), + Connection(253, 373), + Connection(307, 320), + 
Connection(320, 321), + Connection(321, 307), + Connection(425, 427), + Connection(427, 411), + Connection(411, 425), + Connection(421, 313), + Connection(313, 18), + Connection(18, 421), + Connection(321, 405), + Connection(405, 406), + Connection(406, 321), + Connection(320, 404), + Connection(404, 405), + Connection(405, 320), + Connection(315, 16), + Connection(16, 17), + Connection(17, 315), + Connection(426, 425), + Connection(425, 266), + Connection(266, 426), + Connection(377, 400), + Connection(400, 369), + Connection(369, 377), + Connection(322, 391), + Connection(391, 269), + Connection(269, 322), + Connection(417, 465), + Connection(465, 464), + Connection(464, 417), + Connection(386, 257), + Connection(257, 258), + Connection(258, 386), + Connection(466, 260), + Connection(260, 388), + Connection(388, 466), + Connection(456, 399), + Connection(399, 419), + Connection(419, 456), + Connection(284, 332), + Connection(332, 333), + Connection(333, 284), + Connection(417, 285), + Connection(285, 8), + Connection(8, 417), + Connection(346, 340), + Connection(340, 261), + Connection(261, 346), + Connection(413, 441), + Connection(441, 285), + Connection(285, 413), + Connection(327, 460), + Connection(460, 328), + Connection(328, 327), + Connection(355, 371), + Connection(371, 329), + Connection(329, 355), + Connection(392, 439), + Connection(439, 438), + Connection(438, 392), + Connection(382, 341), + Connection(341, 256), + Connection(256, 382), + Connection(429, 420), + Connection(420, 360), + Connection(360, 429), + Connection(364, 394), + Connection(394, 379), + Connection(379, 364), + Connection(277, 343), + Connection(343, 437), + Connection(437, 277), + Connection(443, 444), + Connection(444, 283), + Connection(283, 443), + Connection(275, 440), + Connection(440, 363), + Connection(363, 275), + Connection(431, 262), + Connection(262, 369), + Connection(369, 431), + Connection(297, 338), + Connection(338, 337), + Connection(337, 297), + Connection(273, 375), + Connection(375, 321), + Connection(321, 273), + Connection(450, 451), + Connection(451, 349), + Connection(349, 450), + Connection(446, 342), + Connection(342, 467), + Connection(467, 446), + Connection(293, 334), + Connection(334, 282), + Connection(282, 293), + Connection(458, 461), + Connection(461, 462), + Connection(462, 458), + Connection(276, 353), + Connection(353, 383), + Connection(383, 276), + Connection(308, 324), + Connection(324, 325), + Connection(325, 308), + Connection(276, 300), + Connection(300, 293), + Connection(293, 276), + Connection(372, 345), + Connection(345, 447), + Connection(447, 372), + Connection(352, 345), + Connection(345, 340), + Connection(340, 352), + Connection(274, 1), + Connection(1, 19), + Connection(19, 274), + Connection(456, 248), + Connection(248, 281), + Connection(281, 456), + Connection(436, 427), + Connection(427, 425), + Connection(425, 436), + Connection(381, 256), + Connection(256, 252), + Connection(252, 381), + Connection(269, 391), + Connection(391, 393), + Connection(393, 269), + Connection(200, 199), + Connection(199, 428), + Connection(428, 200), + Connection(266, 330), + Connection(330, 329), + Connection(329, 266), + Connection(287, 273), + Connection(273, 422), + Connection(422, 287), + Connection(250, 462), + Connection(462, 328), + Connection(328, 250), + Connection(258, 286), + Connection(286, 384), + Connection(384, 258), + Connection(265, 353), + Connection(353, 342), + Connection(342, 265), + Connection(387, 259), + Connection(259, 257), + Connection(257, 
387), + Connection(424, 431), + Connection(431, 430), + Connection(430, 424), + Connection(342, 353), + Connection(353, 276), + Connection(276, 342), + Connection(273, 335), + Connection(335, 424), + Connection(424, 273), + Connection(292, 325), + Connection(325, 307), + Connection(307, 292), + Connection(366, 447), + Connection(447, 345), + Connection(345, 366), + Connection(271, 303), + Connection(303, 302), + Connection(302, 271), + Connection(423, 266), + Connection(266, 371), + Connection(371, 423), + Connection(294, 455), + Connection(455, 460), + Connection(460, 294), + Connection(279, 278), + Connection(278, 294), + Connection(294, 279), + Connection(271, 272), + Connection(272, 304), + Connection(304, 271), + Connection(432, 434), + Connection(434, 427), + Connection(427, 432), + Connection(272, 407), + Connection(407, 408), + Connection(408, 272), + Connection(394, 430), + Connection(430, 431), + Connection(431, 394), + Connection(395, 369), + Connection(369, 400), + Connection(400, 395), + Connection(334, 333), + Connection(333, 299), + Connection(299, 334), + Connection(351, 417), + Connection(417, 168), + Connection(168, 351), + Connection(352, 280), + Connection(280, 411), + Connection(411, 352), + Connection(325, 319), + Connection(319, 320), + Connection(320, 325), + Connection(295, 296), + Connection(296, 336), + Connection(336, 295), + Connection(319, 403), + Connection(403, 404), + Connection(404, 319), + Connection(330, 348), + Connection(348, 349), + Connection(349, 330), + Connection(293, 298), + Connection(298, 333), + Connection(333, 293), + Connection(323, 454), + Connection(454, 447), + Connection(447, 323), + Connection(15, 16), + Connection(16, 315), + Connection(315, 15), + Connection(358, 429), + Connection(429, 279), + Connection(279, 358), + Connection(14, 15), + Connection(15, 316), + Connection(316, 14), + Connection(285, 336), + Connection(336, 9), + Connection(9, 285), + Connection(329, 349), + Connection(349, 350), + Connection(350, 329), + Connection(374, 380), + Connection(380, 252), + Connection(252, 374), + Connection(318, 402), + Connection(402, 403), + Connection(403, 318), + Connection(6, 197), + Connection(197, 419), + Connection(419, 6), + Connection(318, 319), + Connection(319, 325), + Connection(325, 318), + Connection(367, 364), + Connection(364, 365), + Connection(365, 367), + Connection(435, 367), + Connection(367, 397), + Connection(397, 435), + Connection(344, 438), + Connection(438, 439), + Connection(439, 344), + Connection(272, 271), + Connection(271, 311), + Connection(311, 272), + Connection(195, 5), + Connection(5, 281), + Connection(281, 195), + Connection(273, 287), + Connection(287, 291), + Connection(291, 273), + Connection(396, 428), + Connection(428, 199), + Connection(199, 396), + Connection(311, 271), + Connection(271, 268), + Connection(268, 311), + Connection(283, 444), + Connection(444, 445), + Connection(445, 283), + Connection(373, 254), + Connection(254, 339), + Connection(339, 373), + Connection(282, 334), + Connection(334, 296), + Connection(296, 282), + Connection(449, 347), + Connection(347, 346), + Connection(346, 449), + Connection(264, 447), + Connection(447, 454), + Connection(454, 264), + Connection(336, 296), + Connection(296, 299), + Connection(299, 336), + Connection(338, 10), + Connection(10, 151), + Connection(151, 338), + Connection(278, 439), + Connection(439, 455), + Connection(455, 278), + Connection(292, 407), + Connection(407, 415), + Connection(415, 292), + Connection(358, 371), + Connection(371, 
355), + Connection(355, 358), + Connection(340, 345), + Connection(345, 372), + Connection(372, 340), + Connection(346, 347), + Connection(347, 280), + Connection(280, 346), + Connection(442, 443), + Connection(443, 282), + Connection(282, 442), + Connection(19, 94), + Connection(94, 370), + Connection(370, 19), + Connection(441, 442), + Connection(442, 295), + Connection(295, 441), + Connection(248, 419), + Connection(419, 197), + Connection(197, 248), + Connection(263, 255), + Connection(255, 359), + Connection(359, 263), + Connection(440, 275), + Connection(275, 274), + Connection(274, 440), + Connection(300, 383), + Connection(383, 368), + Connection(368, 300), + Connection(351, 412), + Connection(412, 465), + Connection(465, 351), + Connection(263, 467), + Connection(467, 466), + Connection(466, 263), + Connection(301, 368), + Connection(368, 389), + Connection(389, 301), + Connection(395, 378), + Connection(378, 379), + Connection(379, 395), + Connection(412, 351), + Connection(351, 419), + Connection(419, 412), + Connection(436, 426), + Connection(426, 322), + Connection(322, 436), + Connection(2, 164), + Connection(164, 393), + Connection(393, 2), + Connection(370, 462), + Connection(462, 461), + Connection(461, 370), + Connection(164, 0), + Connection(0, 267), + Connection(267, 164), + Connection(302, 11), + Connection(11, 12), + Connection(12, 302), + Connection(268, 12), + Connection(12, 13), + Connection(13, 268), + Connection(293, 300), + Connection(300, 301), + Connection(301, 293), + Connection(446, 261), + Connection(261, 340), + Connection(340, 446), + Connection(330, 266), + Connection(266, 425), + Connection(425, 330), + Connection(426, 423), + Connection(423, 391), + Connection(391, 426), + Connection(429, 355), + Connection(355, 437), + Connection(437, 429), + Connection(391, 327), + Connection(327, 326), + Connection(326, 391), + Connection(440, 457), + Connection(457, 438), + Connection(438, 440), + Connection(341, 382), + Connection(382, 362), + Connection(362, 341), + Connection(459, 457), + Connection(457, 461), + Connection(461, 459), + Connection(434, 430), + Connection(430, 394), + Connection(394, 434), + Connection(414, 463), + Connection(463, 362), + Connection(362, 414), + Connection(396, 369), + Connection(369, 262), + Connection(262, 396), + Connection(354, 461), + Connection(461, 457), + Connection(457, 354), + Connection(316, 403), + Connection(403, 402), + Connection(402, 316), + Connection(315, 404), + Connection(404, 403), + Connection(403, 315), + Connection(314, 405), + Connection(405, 404), + Connection(404, 314), + Connection(313, 406), + Connection(406, 405), + Connection(405, 313), + Connection(421, 418), + Connection(418, 406), + Connection(406, 421), + Connection(366, 401), + Connection(401, 361), + Connection(361, 366), + Connection(306, 408), + Connection(408, 407), + Connection(407, 306), + Connection(291, 409), + Connection(409, 408), + Connection(408, 291), + Connection(287, 410), + Connection(410, 409), + Connection(409, 287), + Connection(432, 436), + Connection(436, 410), + Connection(410, 432), + Connection(434, 416), + Connection(416, 411), + Connection(411, 434), + Connection(264, 368), + Connection(368, 383), + Connection(383, 264), + Connection(309, 438), + Connection(438, 457), + Connection(457, 309), + Connection(352, 376), + Connection(376, 401), + Connection(401, 352), + Connection(274, 275), + Connection(275, 4), + Connection(4, 274), + Connection(421, 428), + Connection(428, 262), + Connection(262, 421), + Connection(294, 
327), + Connection(327, 358), + Connection(358, 294), + Connection(433, 416), + Connection(416, 367), + Connection(367, 433), + Connection(289, 455), + Connection(455, 439), + Connection(439, 289), + Connection(462, 370), + Connection(370, 326), + Connection(326, 462), + Connection(2, 326), + Connection(326, 370), + Connection(370, 2), + Connection(305, 460), + Connection(460, 455), + Connection(455, 305), + Connection(254, 449), + Connection(449, 448), + Connection(448, 254), + Connection(255, 261), + Connection(261, 446), + Connection(446, 255), + Connection(253, 450), + Connection(450, 449), + Connection(449, 253), + Connection(252, 451), + Connection(451, 450), + Connection(450, 252), + Connection(256, 452), + Connection(452, 451), + Connection(451, 256), + Connection(341, 453), + Connection(453, 452), + Connection(452, 341), + Connection(413, 464), + Connection(464, 463), + Connection(463, 413), + Connection(441, 413), + Connection(413, 414), + Connection(414, 441), + Connection(258, 442), + Connection(442, 441), + Connection(441, 258), + Connection(257, 443), + Connection(443, 442), + Connection(442, 257), + Connection(259, 444), + Connection(444, 443), + Connection(443, 259), + Connection(260, 445), + Connection(445, 444), + Connection(444, 260), + Connection(467, 342), + Connection(342, 445), + Connection(445, 467), + Connection(459, 458), + Connection(458, 250), + Connection(250, 459), + Connection(289, 392), + Connection(392, 290), + Connection(290, 289), + Connection(290, 328), + Connection(328, 460), + Connection(460, 290), + Connection(376, 433), + Connection(433, 435), + Connection(435, 376), + Connection(250, 290), + Connection(290, 392), + Connection(392, 250), + Connection(411, 416), + Connection(416, 433), + Connection(433, 411), + Connection(341, 463), + Connection(463, 464), + Connection(464, 341), + Connection(453, 464), + Connection(464, 465), + Connection(465, 453), + Connection(357, 465), + Connection(465, 412), + Connection(412, 357), + Connection(343, 412), + Connection(412, 399), + Connection(399, 343), + Connection(360, 363), + Connection(363, 440), + Connection(440, 360), + Connection(437, 399), + Connection(399, 456), + Connection(456, 437), + Connection(420, 456), + Connection(456, 363), + Connection(363, 420), + Connection(401, 435), + Connection(435, 288), + Connection(288, 401), + Connection(372, 383), + Connection(383, 353), + Connection(353, 372), + Connection(339, 255), + Connection(255, 249), + Connection(249, 339), + Connection(448, 261), + Connection(261, 255), + Connection(255, 448), + Connection(133, 243), + Connection(243, 190), + Connection(190, 133), + Connection(133, 155), + Connection(155, 112), + Connection(112, 133), + Connection(33, 246), + Connection(246, 247), + Connection(247, 33), + Connection(33, 130), + Connection(130, 25), + Connection(25, 33), + Connection(398, 384), + Connection(384, 286), + Connection(286, 398), + Connection(362, 398), + Connection(398, 414), + Connection(414, 362), + Connection(362, 463), + Connection(463, 341), + Connection(341, 362), + Connection(263, 359), + Connection(359, 467), + Connection(467, 263), + Connection(263, 249), + Connection(249, 255), + Connection(255, 263), + Connection(466, 467), + Connection(467, 260), + Connection(260, 466), + Connection(75, 60), + Connection(60, 166), + Connection(166, 75), + Connection(238, 239), + Connection(239, 79), + Connection(79, 238), + Connection(162, 127), + Connection(127, 139), + Connection(139, 162), + Connection(72, 11), + Connection(11, 37), + Connection(37, 
72), + Connection(121, 232), + Connection(232, 120), + Connection(120, 121), + Connection(73, 72), + Connection(72, 39), + Connection(39, 73), + Connection(114, 128), + Connection(128, 47), + Connection(47, 114), + Connection(233, 232), + Connection(232, 128), + Connection(128, 233), + Connection(103, 104), + Connection(104, 67), + Connection(67, 103), + Connection(152, 175), + Connection(175, 148), + Connection(148, 152), + Connection(119, 118), + Connection(118, 101), + Connection(101, 119), + Connection(74, 73), + Connection(73, 40), + Connection(40, 74), + Connection(107, 9), + Connection(9, 108), + Connection(108, 107), + Connection(49, 48), + Connection(48, 131), + Connection(131, 49), + Connection(32, 194), + Connection(194, 211), + Connection(211, 32), + Connection(184, 74), + Connection(74, 185), + Connection(185, 184), + Connection(191, 80), + Connection(80, 183), + Connection(183, 191), + Connection(185, 40), + Connection(40, 186), + Connection(186, 185), + Connection(119, 230), + Connection(230, 118), + Connection(118, 119), + Connection(210, 202), + Connection(202, 214), + Connection(214, 210), + Connection(84, 83), + Connection(83, 17), + Connection(17, 84), + Connection(77, 76), + Connection(76, 146), + Connection(146, 77), + Connection(161, 160), + Connection(160, 30), + Connection(30, 161), + Connection(190, 56), + Connection(56, 173), + Connection(173, 190), + Connection(182, 106), + Connection(106, 194), + Connection(194, 182), + Connection(138, 135), + Connection(135, 192), + Connection(192, 138), + Connection(129, 203), + Connection(203, 98), + Connection(98, 129), + Connection(54, 21), + Connection(21, 68), + Connection(68, 54), + Connection(5, 51), + Connection(51, 4), + Connection(4, 5), + Connection(145, 144), + Connection(144, 23), + Connection(23, 145), + Connection(90, 77), + Connection(77, 91), + Connection(91, 90), + Connection(207, 205), + Connection(205, 187), + Connection(187, 207), + Connection(83, 201), + Connection(201, 18), + Connection(18, 83), + Connection(181, 91), + Connection(91, 182), + Connection(182, 181), + Connection(180, 90), + Connection(90, 181), + Connection(181, 180), + Connection(16, 85), + Connection(85, 17), + Connection(17, 16), + Connection(205, 206), + Connection(206, 36), + Connection(36, 205), + Connection(176, 148), + Connection(148, 140), + Connection(140, 176), + Connection(165, 92), + Connection(92, 39), + Connection(39, 165), + Connection(245, 193), + Connection(193, 244), + Connection(244, 245), + Connection(27, 159), + Connection(159, 28), + Connection(28, 27), + Connection(30, 247), + Connection(247, 161), + Connection(161, 30), + Connection(174, 236), + Connection(236, 196), + Connection(196, 174), + Connection(103, 54), + Connection(54, 104), + Connection(104, 103), + Connection(55, 193), + Connection(193, 8), + Connection(8, 55), + Connection(111, 117), + Connection(117, 31), + Connection(31, 111), + Connection(221, 189), + Connection(189, 55), + Connection(55, 221), + Connection(240, 98), + Connection(98, 99), + Connection(99, 240), + Connection(142, 126), + Connection(126, 100), + Connection(100, 142), + Connection(219, 166), + Connection(166, 218), + Connection(218, 219), + Connection(112, 155), + Connection(155, 26), + Connection(26, 112), + Connection(198, 209), + Connection(209, 131), + Connection(131, 198), + Connection(169, 135), + Connection(135, 150), + Connection(150, 169), + Connection(114, 47), + Connection(47, 217), + Connection(217, 114), + Connection(224, 223), + Connection(223, 53), + Connection(53, 
224), + Connection(220, 45), + Connection(45, 134), + Connection(134, 220), + Connection(32, 211), + Connection(211, 140), + Connection(140, 32), + Connection(109, 67), + Connection(67, 108), + Connection(108, 109), + Connection(146, 43), + Connection(43, 91), + Connection(91, 146), + Connection(231, 230), + Connection(230, 120), + Connection(120, 231), + Connection(113, 226), + Connection(226, 247), + Connection(247, 113), + Connection(105, 63), + Connection(63, 52), + Connection(52, 105), + Connection(241, 238), + Connection(238, 242), + Connection(242, 241), + Connection(124, 46), + Connection(46, 156), + Connection(156, 124), + Connection(95, 78), + Connection(78, 96), + Connection(96, 95), + Connection(70, 46), + Connection(46, 63), + Connection(63, 70), + Connection(116, 143), + Connection(143, 227), + Connection(227, 116), + Connection(116, 123), + Connection(123, 111), + Connection(111, 116), + Connection(1, 44), + Connection(44, 19), + Connection(19, 1), + Connection(3, 236), + Connection(236, 51), + Connection(51, 3), + Connection(207, 216), + Connection(216, 205), + Connection(205, 207), + Connection(26, 154), + Connection(154, 22), + Connection(22, 26), + Connection(165, 39), + Connection(39, 167), + Connection(167, 165), + Connection(199, 200), + Connection(200, 208), + Connection(208, 199), + Connection(101, 36), + Connection(36, 100), + Connection(100, 101), + Connection(43, 57), + Connection(57, 202), + Connection(202, 43), + Connection(242, 20), + Connection(20, 99), + Connection(99, 242), + Connection(56, 28), + Connection(28, 157), + Connection(157, 56), + Connection(124, 35), + Connection(35, 113), + Connection(113, 124), + Connection(29, 160), + Connection(160, 27), + Connection(27, 29), + Connection(211, 204), + Connection(204, 210), + Connection(210, 211), + Connection(124, 113), + Connection(113, 46), + Connection(46, 124), + Connection(106, 43), + Connection(43, 204), + Connection(204, 106), + Connection(96, 62), + Connection(62, 77), + Connection(77, 96), + Connection(227, 137), + Connection(137, 116), + Connection(116, 227), + Connection(73, 41), + Connection(41, 72), + Connection(72, 73), + Connection(36, 203), + Connection(203, 142), + Connection(142, 36), + Connection(235, 64), + Connection(64, 240), + Connection(240, 235), + Connection(48, 49), + Connection(49, 64), + Connection(64, 48), + Connection(42, 41), + Connection(41, 74), + Connection(74, 42), + Connection(214, 212), + Connection(212, 207), + Connection(207, 214), + Connection(183, 42), + Connection(42, 184), + Connection(184, 183), + Connection(210, 169), + Connection(169, 211), + Connection(211, 210), + Connection(140, 170), + Connection(170, 176), + Connection(176, 140), + Connection(104, 105), + Connection(105, 69), + Connection(69, 104), + Connection(193, 122), + Connection(122, 168), + Connection(168, 193), + Connection(50, 123), + Connection(123, 187), + Connection(187, 50), + Connection(89, 96), + Connection(96, 90), + Connection(90, 89), + Connection(66, 65), + Connection(65, 107), + Connection(107, 66), + Connection(179, 89), + Connection(89, 180), + Connection(180, 179), + Connection(119, 101), + Connection(101, 120), + Connection(120, 119), + Connection(68, 63), + Connection(63, 104), + Connection(104, 68), + Connection(234, 93), + Connection(93, 227), + Connection(227, 234), + Connection(16, 15), + Connection(15, 85), + Connection(85, 16), + Connection(209, 129), + Connection(129, 49), + Connection(49, 209), + Connection(15, 14), + Connection(14, 86), + Connection(86, 15), + 
Connection(107, 55), + Connection(55, 9), + Connection(9, 107), + Connection(120, 100), + Connection(100, 121), + Connection(121, 120), + Connection(153, 145), + Connection(145, 22), + Connection(22, 153), + Connection(178, 88), + Connection(88, 179), + Connection(179, 178), + Connection(197, 6), + Connection(6, 196), + Connection(196, 197), + Connection(89, 88), + Connection(88, 96), + Connection(96, 89), + Connection(135, 138), + Connection(138, 136), + Connection(136, 135), + Connection(138, 215), + Connection(215, 172), + Connection(172, 138), + Connection(218, 115), + Connection(115, 219), + Connection(219, 218), + Connection(41, 42), + Connection(42, 81), + Connection(81, 41), + Connection(5, 195), + Connection(195, 51), + Connection(51, 5), + Connection(57, 43), + Connection(43, 61), + Connection(61, 57), + Connection(208, 171), + Connection(171, 199), + Connection(199, 208), + Connection(41, 81), + Connection(81, 38), + Connection(38, 41), + Connection(224, 53), + Connection(53, 225), + Connection(225, 224), + Connection(24, 144), + Connection(144, 110), + Connection(110, 24), + Connection(105, 52), + Connection(52, 66), + Connection(66, 105), + Connection(118, 229), + Connection(229, 117), + Connection(117, 118), + Connection(227, 34), + Connection(34, 234), + Connection(234, 227), + Connection(66, 107), + Connection(107, 69), + Connection(69, 66), + Connection(10, 109), + Connection(109, 151), + Connection(151, 10), + Connection(219, 48), + Connection(48, 235), + Connection(235, 219), + Connection(183, 62), + Connection(62, 191), + Connection(191, 183), + Connection(142, 129), + Connection(129, 126), + Connection(126, 142), + Connection(116, 111), + Connection(111, 143), + Connection(143, 116), + Connection(118, 117), + Connection(117, 50), + Connection(50, 118), + Connection(223, 222), + Connection(222, 52), + Connection(52, 223), + Connection(94, 19), + Connection(19, 141), + Connection(141, 94), + Connection(222, 221), + Connection(221, 65), + Connection(65, 222), + Connection(196, 3), + Connection(3, 197), + Connection(197, 196), + Connection(45, 220), + Connection(220, 44), + Connection(44, 45), + Connection(156, 70), + Connection(70, 139), + Connection(139, 156), + Connection(188, 122), + Connection(122, 245), + Connection(245, 188), + Connection(139, 71), + Connection(71, 162), + Connection(162, 139), + Connection(149, 170), + Connection(170, 150), + Connection(150, 149), + Connection(122, 188), + Connection(188, 196), + Connection(196, 122), + Connection(206, 216), + Connection(216, 92), + Connection(92, 206), + Connection(164, 2), + Connection(2, 167), + Connection(167, 164), + Connection(242, 141), + Connection(141, 241), + Connection(241, 242), + Connection(0, 164), + Connection(164, 37), + Connection(37, 0), + Connection(11, 72), + Connection(72, 12), + Connection(12, 11), + Connection(12, 38), + Connection(38, 13), + Connection(13, 12), + Connection(70, 63), + Connection(63, 71), + Connection(71, 70), + Connection(31, 226), + Connection(226, 111), + Connection(111, 31), + Connection(36, 101), + Connection(101, 205), + Connection(205, 36), + Connection(203, 206), + Connection(206, 165), + Connection(165, 203), + Connection(126, 209), + Connection(209, 217), + Connection(217, 126), + Connection(98, 165), + Connection(165, 97), + Connection(97, 98), + Connection(237, 220), + Connection(220, 218), + Connection(218, 237), + Connection(237, 239), + Connection(239, 241), + Connection(241, 237), + Connection(210, 214), + Connection(214, 169), + Connection(169, 210), + 
Connection(140, 171), + Connection(171, 32), + Connection(32, 140), + Connection(241, 125), + Connection(125, 237), + Connection(237, 241), + Connection(179, 86), + Connection(86, 178), + Connection(178, 179), + Connection(180, 85), + Connection(85, 179), + Connection(179, 180), + Connection(181, 84), + Connection(84, 180), + Connection(180, 181), + Connection(182, 83), + Connection(83, 181), + Connection(181, 182), + Connection(194, 201), + Connection(201, 182), + Connection(182, 194), + Connection(177, 137), + Connection(137, 132), + Connection(132, 177), + Connection(184, 76), + Connection(76, 183), + Connection(183, 184), + Connection(185, 61), + Connection(61, 184), + Connection(184, 185), + Connection(186, 57), + Connection(57, 185), + Connection(185, 186), + Connection(216, 212), + Connection(212, 186), + Connection(186, 216), + Connection(192, 214), + Connection(214, 187), + Connection(187, 192), + Connection(139, 34), + Connection(34, 156), + Connection(156, 139), + Connection(218, 79), + Connection(79, 237), + Connection(237, 218), + Connection(147, 123), + Connection(123, 177), + Connection(177, 147), + Connection(45, 44), + Connection(44, 4), + Connection(4, 45), + Connection(208, 201), + Connection(201, 32), + Connection(32, 208), + Connection(98, 64), + Connection(64, 129), + Connection(129, 98), + Connection(192, 213), + Connection(213, 138), + Connection(138, 192), + Connection(235, 59), + Connection(59, 219), + Connection(219, 235), + Connection(141, 242), + Connection(242, 97), + Connection(97, 141), + Connection(97, 2), + Connection(2, 141), + Connection(141, 97), + Connection(240, 75), + Connection(75, 235), + Connection(235, 240), + Connection(229, 24), + Connection(24, 228), + Connection(228, 229), + Connection(31, 25), + Connection(25, 226), + Connection(226, 31), + Connection(230, 23), + Connection(23, 229), + Connection(229, 230), + Connection(231, 22), + Connection(22, 230), + Connection(230, 231), + Connection(232, 26), + Connection(26, 231), + Connection(231, 232), + Connection(233, 112), + Connection(112, 232), + Connection(232, 233), + Connection(244, 189), + Connection(189, 243), + Connection(243, 244), + Connection(189, 221), + Connection(221, 190), + Connection(190, 189), + Connection(222, 28), + Connection(28, 221), + Connection(221, 222), + Connection(223, 27), + Connection(27, 222), + Connection(222, 223), + Connection(224, 29), + Connection(29, 223), + Connection(223, 224), + Connection(225, 30), + Connection(30, 224), + Connection(224, 225), + Connection(113, 247), + Connection(247, 225), + Connection(225, 113), + Connection(99, 60), + Connection(60, 240), + Connection(240, 99), + Connection(213, 147), + Connection(147, 215), + Connection(215, 213), + Connection(60, 20), + Connection(20, 166), + Connection(166, 60), + Connection(192, 187), + Connection(187, 213), + Connection(213, 192), + Connection(243, 112), + Connection(112, 244), + Connection(244, 243), + Connection(244, 233), + Connection(233, 245), + Connection(245, 244), + Connection(245, 128), + Connection(128, 188), + Connection(188, 245), + Connection(188, 114), + Connection(114, 174), + Connection(174, 188), + Connection(134, 131), + Connection(131, 220), + Connection(220, 134), + Connection(174, 217), + Connection(217, 236), + Connection(236, 174), + Connection(236, 198), + Connection(198, 134), + Connection(134, 236), + Connection(215, 177), + Connection(177, 58), + Connection(58, 215), + Connection(156, 143), + Connection(143, 124), + Connection(124, 156), + Connection(25, 110), + 
Connection(110, 7), + Connection(7, 25), + Connection(31, 228), + Connection(228, 25), + Connection(25, 31), + Connection(264, 356), + Connection(356, 368), + Connection(368, 264), + Connection(0, 11), + Connection(11, 267), + Connection(267, 0), + Connection(451, 452), + Connection(452, 349), + Connection(349, 451), + Connection(267, 302), + Connection(302, 269), + Connection(269, 267), + Connection(350, 357), + Connection(357, 277), + Connection(277, 350), + Connection(350, 452), + Connection(452, 357), + Connection(357, 350), + Connection(299, 333), + Connection(333, 297), + Connection(297, 299), + Connection(396, 175), + Connection(175, 377), + Connection(377, 396), + Connection(280, 347), + Connection(347, 330), + Connection(330, 280), + Connection(269, 303), + Connection(303, 270), + Connection(270, 269), + Connection(151, 9), + Connection(9, 337), + Connection(337, 151), + Connection(344, 278), + Connection(278, 360), + Connection(360, 344), + Connection(424, 418), + Connection(418, 431), + Connection(431, 424), + Connection(270, 304), + Connection(304, 409), + Connection(409, 270), + Connection(272, 310), + Connection(310, 407), + Connection(407, 272), + Connection(322, 270), + Connection(270, 410), + Connection(410, 322), + Connection(449, 450), + Connection(450, 347), + Connection(347, 449), + Connection(432, 422), + Connection(422, 434), + Connection(434, 432), + Connection(18, 313), + Connection(313, 17), + Connection(17, 18), + Connection(291, 306), + Connection(306, 375), + Connection(375, 291), + Connection(259, 387), + Connection(387, 260), + Connection(260, 259), + Connection(424, 335), + Connection(335, 418), + Connection(418, 424), + Connection(434, 364), + Connection(364, 416), + Connection(416, 434), + Connection(391, 423), + Connection(423, 327), + Connection(327, 391), + Connection(301, 251), + Connection(251, 298), + Connection(298, 301), + Connection(275, 281), + Connection(281, 4), + Connection(4, 275), + Connection(254, 373), + Connection(373, 253), + Connection(253, 254), + Connection(375, 307), + Connection(307, 321), + Connection(321, 375), + Connection(280, 425), + Connection(425, 411), + Connection(411, 280), + Connection(200, 421), + Connection(421, 18), + Connection(18, 200), + Connection(335, 321), + Connection(321, 406), + Connection(406, 335), + Connection(321, 320), + Connection(320, 405), + Connection(405, 321), + Connection(314, 315), + Connection(315, 17), + Connection(17, 314), + Connection(423, 426), + Connection(426, 266), + Connection(266, 423), + Connection(396, 377), + Connection(377, 369), + Connection(369, 396), + Connection(270, 322), + Connection(322, 269), + Connection(269, 270), + Connection(413, 417), + Connection(417, 464), + Connection(464, 413), + Connection(385, 386), + Connection(386, 258), + Connection(258, 385), + Connection(248, 456), + Connection(456, 419), + Connection(419, 248), + Connection(298, 284), + Connection(284, 333), + Connection(333, 298), + Connection(168, 417), + Connection(417, 8), + Connection(8, 168), + Connection(448, 346), + Connection(346, 261), + Connection(261, 448), + Connection(417, 413), + Connection(413, 285), + Connection(285, 417), + Connection(326, 327), + Connection(327, 328), + Connection(328, 326), + Connection(277, 355), + Connection(355, 329), + Connection(329, 277), + Connection(309, 392), + Connection(392, 438), + Connection(438, 309), + Connection(381, 382), + Connection(382, 256), + Connection(256, 381), + Connection(279, 429), + Connection(429, 360), + Connection(360, 279), + 
Connection(365, 364), + Connection(364, 379), + Connection(379, 365), + Connection(355, 277), + Connection(277, 437), + Connection(437, 355), + Connection(282, 443), + Connection(443, 283), + Connection(283, 282), + Connection(281, 275), + Connection(275, 363), + Connection(363, 281), + Connection(395, 431), + Connection(431, 369), + Connection(369, 395), + Connection(299, 297), + Connection(297, 337), + Connection(337, 299), + Connection(335, 273), + Connection(273, 321), + Connection(321, 335), + Connection(348, 450), + Connection(450, 349), + Connection(349, 348), + Connection(359, 446), + Connection(446, 467), + Connection(467, 359), + Connection(283, 293), + Connection(293, 282), + Connection(282, 283), + Connection(250, 458), + Connection(458, 462), + Connection(462, 250), + Connection(300, 276), + Connection(276, 383), + Connection(383, 300), + Connection(292, 308), + Connection(308, 325), + Connection(325, 292), + Connection(283, 276), + Connection(276, 293), + Connection(293, 283), + Connection(264, 372), + Connection(372, 447), + Connection(447, 264), + Connection(346, 352), + Connection(352, 340), + Connection(340, 346), + Connection(354, 274), + Connection(274, 19), + Connection(19, 354), + Connection(363, 456), + Connection(456, 281), + Connection(281, 363), + Connection(426, 436), + Connection(436, 425), + Connection(425, 426), + Connection(380, 381), + Connection(381, 252), + Connection(252, 380), + Connection(267, 269), + Connection(269, 393), + Connection(393, 267), + Connection(421, 200), + Connection(200, 428), + Connection(428, 421), + Connection(371, 266), + Connection(266, 329), + Connection(329, 371), + Connection(432, 287), + Connection(287, 422), + Connection(422, 432), + Connection(290, 250), + Connection(250, 328), + Connection(328, 290), + Connection(385, 258), + Connection(258, 384), + Connection(384, 385), + Connection(446, 265), + Connection(265, 342), + Connection(342, 446), + Connection(386, 387), + Connection(387, 257), + Connection(257, 386), + Connection(422, 424), + Connection(424, 430), + Connection(430, 422), + Connection(445, 342), + Connection(342, 276), + Connection(276, 445), + Connection(422, 273), + Connection(273, 424), + Connection(424, 422), + Connection(306, 292), + Connection(292, 307), + Connection(307, 306), + Connection(352, 366), + Connection(366, 345), + Connection(345, 352), + Connection(268, 271), + Connection(271, 302), + Connection(302, 268), + Connection(358, 423), + Connection(423, 371), + Connection(371, 358), + Connection(327, 294), + Connection(294, 460), + Connection(460, 327), + Connection(331, 279), + Connection(279, 294), + Connection(294, 331), + Connection(303, 271), + Connection(271, 304), + Connection(304, 303), + Connection(436, 432), + Connection(432, 427), + Connection(427, 436), + Connection(304, 272), + Connection(272, 408), + Connection(408, 304), + Connection(395, 394), + Connection(394, 431), + Connection(431, 395), + Connection(378, 395), + Connection(395, 400), + Connection(400, 378), + Connection(296, 334), + Connection(334, 299), + Connection(299, 296), + Connection(6, 351), + Connection(351, 168), + Connection(168, 6), + Connection(376, 352), + Connection(352, 411), + Connection(411, 376), + Connection(307, 325), + Connection(325, 320), + Connection(320, 307), + Connection(285, 295), + Connection(295, 336), + Connection(336, 285), + Connection(320, 319), + Connection(319, 404), + Connection(404, 320), + Connection(329, 330), + Connection(330, 349), + Connection(349, 329), + Connection(334, 293), + 
Connection(293, 333), + Connection(333, 334), + Connection(366, 323), + Connection(323, 447), + Connection(447, 366), + Connection(316, 15), + Connection(15, 315), + Connection(315, 316), + Connection(331, 358), + Connection(358, 279), + Connection(279, 331), + Connection(317, 14), + Connection(14, 316), + Connection(316, 317), + Connection(8, 285), + Connection(285, 9), + Connection(9, 8), + Connection(277, 329), + Connection(329, 350), + Connection(350, 277), + Connection(253, 374), + Connection(374, 252), + Connection(252, 253), + Connection(319, 318), + Connection(318, 403), + Connection(403, 319), + Connection(351, 6), + Connection(6, 419), + Connection(419, 351), + Connection(324, 318), + Connection(318, 325), + Connection(325, 324), + Connection(397, 367), + Connection(367, 365), + Connection(365, 397), + Connection(288, 435), + Connection(435, 397), + Connection(397, 288), + Connection(278, 344), + Connection(344, 439), + Connection(439, 278), + Connection(310, 272), + Connection(272, 311), + Connection(311, 310), + Connection(248, 195), + Connection(195, 281), + Connection(281, 248), + Connection(375, 273), + Connection(273, 291), + Connection(291, 375), + Connection(175, 396), + Connection(396, 199), + Connection(199, 175), + Connection(312, 311), + Connection(311, 268), + Connection(268, 312), + Connection(276, 283), + Connection(283, 445), + Connection(445, 276), + Connection(390, 373), + Connection(373, 339), + Connection(339, 390), + Connection(295, 282), + Connection(282, 296), + Connection(296, 295), + Connection(448, 449), + Connection(449, 346), + Connection(346, 448), + Connection(356, 264), + Connection(264, 454), + Connection(454, 356), + Connection(337, 336), + Connection(336, 299), + Connection(299, 337), + Connection(337, 338), + Connection(338, 151), + Connection(151, 337), + Connection(294, 278), + Connection(278, 455), + Connection(455, 294), + Connection(308, 292), + Connection(292, 415), + Connection(415, 308), + Connection(429, 358), + Connection(358, 355), + Connection(355, 429), + Connection(265, 340), + Connection(340, 372), + Connection(372, 265), + Connection(352, 346), + Connection(346, 280), + Connection(280, 352), + Connection(295, 442), + Connection(442, 282), + Connection(282, 295), + Connection(354, 19), + Connection(19, 370), + Connection(370, 354), + Connection(285, 441), + Connection(441, 295), + Connection(295, 285), + Connection(195, 248), + Connection(248, 197), + Connection(197, 195), + Connection(457, 440), + Connection(440, 274), + Connection(274, 457), + Connection(301, 300), + Connection(300, 368), + Connection(368, 301), + Connection(417, 351), + Connection(351, 465), + Connection(465, 417), + Connection(251, 301), + Connection(301, 389), + Connection(389, 251), + Connection(394, 395), + Connection(395, 379), + Connection(379, 394), + Connection(399, 412), + Connection(412, 419), + Connection(419, 399), + Connection(410, 436), + Connection(436, 322), + Connection(322, 410), + Connection(326, 2), + Connection(2, 393), + Connection(393, 326), + Connection(354, 370), + Connection(370, 461), + Connection(461, 354), + Connection(393, 164), + Connection(164, 267), + Connection(267, 393), + Connection(268, 302), + Connection(302, 12), + Connection(12, 268), + Connection(312, 268), + Connection(268, 13), + Connection(13, 312), + Connection(298, 293), + Connection(293, 301), + Connection(301, 298), + Connection(265, 446), + Connection(446, 340), + Connection(340, 265), + Connection(280, 330), + Connection(330, 425), + Connection(425, 280), + 
Connection(322, 426), + Connection(426, 391), + Connection(391, 322), + Connection(420, 429), + Connection(429, 437), + Connection(437, 420), + Connection(393, 391), + Connection(391, 326), + Connection(326, 393), + Connection(344, 440), + Connection(440, 438), + Connection(438, 344), + Connection(458, 459), + Connection(459, 461), + Connection(461, 458), + Connection(364, 434), + Connection(434, 394), + Connection(394, 364), + Connection(428, 396), + Connection(396, 262), + Connection(262, 428), + Connection(274, 354), + Connection(354, 457), + Connection(457, 274), + Connection(317, 316), + Connection(316, 402), + Connection(402, 317), + Connection(316, 315), + Connection(315, 403), + Connection(403, 316), + Connection(315, 314), + Connection(314, 404), + Connection(404, 315), + Connection(314, 313), + Connection(313, 405), + Connection(405, 314), + Connection(313, 421), + Connection(421, 406), + Connection(406, 313), + Connection(323, 366), + Connection(366, 361), + Connection(361, 323), + Connection(292, 306), + Connection(306, 407), + Connection(407, 292), + Connection(306, 291), + Connection(291, 408), + Connection(408, 306), + Connection(291, 287), + Connection(287, 409), + Connection(409, 291), + Connection(287, 432), + Connection(432, 410), + Connection(410, 287), + Connection(427, 434), + Connection(434, 411), + Connection(411, 427), + Connection(372, 264), + Connection(264, 383), + Connection(383, 372), + Connection(459, 309), + Connection(309, 457), + Connection(457, 459), + Connection(366, 352), + Connection(352, 401), + Connection(401, 366), + Connection(1, 274), + Connection(274, 4), + Connection(4, 1), + Connection(418, 421), + Connection(421, 262), + Connection(262, 418), + Connection(331, 294), + Connection(294, 358), + Connection(358, 331), + Connection(435, 433), + Connection(433, 367), + Connection(367, 435), + Connection(392, 289), + Connection(289, 439), + Connection(439, 392), + Connection(328, 462), + Connection(462, 326), + Connection(326, 328), + Connection(94, 2), + Connection(2, 370), + Connection(370, 94), + Connection(289, 305), + Connection(305, 455), + Connection(455, 289), + Connection(339, 254), + Connection(254, 448), + Connection(448, 339), + Connection(359, 255), + Connection(255, 446), + Connection(446, 359), + Connection(254, 253), + Connection(253, 449), + Connection(449, 254), + Connection(253, 252), + Connection(252, 450), + Connection(450, 253), + Connection(252, 256), + Connection(256, 451), + Connection(451, 252), + Connection(256, 341), + Connection(341, 452), + Connection(452, 256), + Connection(414, 413), + Connection(413, 463), + Connection(463, 414), + Connection(286, 441), + Connection(441, 414), + Connection(414, 286), + Connection(286, 258), + Connection(258, 441), + Connection(441, 286), + Connection(258, 257), + Connection(257, 442), + Connection(442, 258), + Connection(257, 259), + Connection(259, 443), + Connection(443, 257), + Connection(259, 260), + Connection(260, 444), + Connection(444, 259), + Connection(260, 467), + Connection(467, 445), + Connection(445, 260), + Connection(309, 459), + Connection(459, 250), + Connection(250, 309), + Connection(305, 289), + Connection(289, 290), + Connection(290, 305), + Connection(305, 290), + Connection(290, 460), + Connection(460, 305), + Connection(401, 376), + Connection(376, 435), + Connection(435, 401), + Connection(309, 250), + Connection(250, 392), + Connection(392, 309), + Connection(376, 411), + Connection(411, 433), + Connection(433, 376), + Connection(453, 341), + Connection(341, 
464), + Connection(464, 453), + Connection(357, 453), + Connection(453, 465), + Connection(465, 357), + Connection(343, 357), + Connection(357, 412), + Connection(412, 343), + Connection(437, 343), + Connection(343, 399), + Connection(399, 437), + Connection(344, 360), + Connection(360, 440), + Connection(440, 344), + Connection(420, 437), + Connection(437, 456), + Connection(456, 420), + Connection(360, 420), + Connection(420, 363), + Connection(363, 360), + Connection(361, 401), + Connection(401, 288), + Connection(288, 361), + Connection(265, 372), + Connection(372, 353), + Connection(353, 265), + Connection(390, 339), + Connection(339, 249), + Connection(249, 390), + Connection(339, 448), + Connection(448, 255), + Connection(255, 339), + ] + + @dataclasses.dataclass class FaceLandmarkerResult: """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image. diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py index c6470b19f..0b10a2b40 100644 --- a/mediapipe/tasks/python/vision/face_stylizer.py +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -176,16 +176,13 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): Only use this method when the FaceStylizer is created with the image running mode. - To ensure that the output image has reasonable quality, the stylized output - image size is the smaller of the model output size and the size of the - `region_of_interest` specified in `image_processing_options`. - Args: image: MediaPipe Image. image_processing_options: Options for image processing. Returns: - The stylized image of the most visible face. None if no face is detected + The stylized image of the most visible face. The stylized output image + size is the same as the model output size. None if no face is detected on the input image. Raises: @@ -217,17 +214,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): milliseconds) along with the video frame. The input timestamps should be monotonically increasing for adjacent calls of this method. - To ensure that the output image has reasonable quality, the stylized output - image size is the smaller of the model output size and the size of the - `region_of_interest` specified in `image_processing_options`. - Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. image_processing_options: Options for image processing. Returns: - The stylized image of the most visible face. None if no face is detected + The stylized image of the most visible face. The stylized output image + size is the same as the model output size. None if no face is detected on the input image. Raises: @@ -266,12 +260,9 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): images if needed. In other words, it's not guaranteed to have output per input image. - To ensure that the stylized image has reasonable quality, the stylized - output image size is the smaller of the model output size and the size of - the `region_of_interest` specified in `image_processing_options`. - The `result_callback` provides: - - The stylized image of the most visible face. None if no face is detected + - The stylized image of the most visible face. The stylized output image + size is the same as the model output size. None if no face is detected on the input image. - The input image that the face stylizer runs on. - The input timestamp in milliseconds. 
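For reference, a minimal sketch of how the revised stylize() call reads under the new behavior documented above (the stylized output now has the model output size rather than being clamped to the region_of_interest size). This snippet is not part of the change set: the model asset name and image path are placeholders, and FaceStylizerOptions / create_from_options are assumed to follow the standard MediaPipe Tasks Python API pattern.

import mediapipe as mp
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_stylizer

# Hypothetical assets: substitute a real face stylizer model and input image.
options = face_stylizer.FaceStylizerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='face_stylizer.task'
    )
)
with face_stylizer.FaceStylizer.create_from_options(options) as stylizer:
  image = mp.Image.create_from_file('portrait.jpg')
  stylized_image = stylizer.stylize(image)
  if stylized_image is not None:
    # Per the updated docstring, the stylized image has the model's output
    # size; it is no longer clamped to the region_of_interest size.
    print(stylized_image.width, stylized_image.height)

The same reading applies to the video and live-stream variants, whose docstrings receive the identical update in the hunks above.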
diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index e50ffbf79..d2ebdda1c 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -14,7 +14,6 @@ """MediaPipe image segmenter task.""" import dataclasses -import enum from typing import Callable, List, Mapping, Optional from mediapipe.python import packet_creator @@ -31,7 +30,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode -ImageSegmenterResult = List[image_module.Image] _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions @@ -42,8 +40,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' -_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' +_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' +_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' +_CATEGORY_MASK_STREAM_NAME = 'category_mask' +_CATEGORY_MASK_TAG = 'CATEGORY_MASK' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' @@ -53,6 +53,21 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +@dataclasses.dataclass +class ImageSegmenterResult: + """Output result of ImageSegmenter. + + confidence_masks: multiple masks of float image where, for each mask, each + pixel represents the prediction confidence, usually in the [0, 1] range. + + category_mask: a category mask of uint8 image where each pixel represents the + class which the pixel in the original image was predicted to belong to. + """ + + confidence_masks: Optional[List[image_module.Image]] = None + category_mask: Optional[image_module.Image] = None + + @dataclasses.dataclass class ImageSegmenterOptions: """Options for the image segmenter task. @@ -64,28 +79,17 @@ class ImageSegmenterOptions: objects on single image inputs. 2) The video mode for segmenting objects on the decoded frames of a video. 3) The live stream mode for segmenting objects on a live stream of input data, such as from camera. - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. - activation: Activation function to apply to input tensor. + output_confidence_masks: Whether to output confidence masks. + output_category_mask: Whether to output category mask. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" - class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - - class Activation(enum.Enum): - NONE = 0 - SIGMOID = 1 - SOFTMAX = 2 - base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK - activation: Optional[Activation] = Activation.NONE + output_confidence_masks: bool = True + output_category_mask: bool = False result_callback: Optional[ Callable[[ImageSegmenterResult, image_module.Image, int], None] ] = None @@ -97,9 +101,7 @@ class ImageSegmenterOptions: base_options_proto.use_stream_mode = ( False if self.running_mode == _RunningMode.IMAGE else True ) - segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value, activation=self.activation.value - ) + segmenter_options_proto = _SegmenterOptionsProto() return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto, @@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): def packets_callback(output_packets: Mapping[str, packet.Packet]): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + + segmentation_result = ImageSegmenterResult() + + if options.output_confidence_masks: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if options.output_category_mask: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) - timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback( segmentation_result, image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, ) + output_streams = [ + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_confidence_masks: + output_streams.append( + ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) + ) + + if options.output_category_mask: + output_streams.append( + ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) + ) + task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], - output_streams=[ - ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ], + output_streams=output_streams, task_options=options, ) return cls( @@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = ImageSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result def segment_for_video( @@ -285,9 +317,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) - segmentation_result = 
packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = ImageSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result def segment_async( diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py index 12a30b6ef..ad93c798c 100644 --- a/mediapipe/tasks/python/vision/interactive_segmenter.py +++ b/mediapipe/tasks/python/vision/interactive_segmenter.py @@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' -_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' +_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' +_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' +_CATEGORY_MASK_STREAM_NAME = 'category_mask' +_CATEGORY_MASK_TAG = 'CATEGORY_MASK' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _ROI_STREAM_NAME = 'roi_in' @@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = ( ) +@dataclasses.dataclass +class InteractiveSegmenterResult: + """Output result of InteractiveSegmenter. + + confidence_masks: multiple masks of float image where, for each mask, each + pixel represents the prediction confidence, usually in the [0, 1] range. + + category_mask: a category mask of uint8 image where each pixel represents the + class which the pixel in the original image was predicted to belong to. + """ + + confidence_masks: Optional[List[image_module.Image]] = None + category_mask: Optional[image_module.Image] = None + + @dataclasses.dataclass class InteractiveSegmenterOptions: """Options for the interactive segmenter task. Attributes: base_options: Base options for the interactive segmenter task. - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. + output_confidence_masks: Whether to output confidence masks. + output_category_mask: Whether to output category mask. """ - class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - base_options: _BaseOptions - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK + output_confidence_masks: bool = True + output_category_mask: bool = False @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: """Generates an InteractiveSegmenterOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False - segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value - ) + segmenter_options_proto = _SegmenterOptionsProto() return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto, @@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If other types of error occurred. 
""" + output_streams = [ + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_confidence_masks: + output_streams.append( + ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) + ) + + if options.output_category_mask: + output_streams.append( + ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) + ) + task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, input_streams=[ @@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): ':'.join([_ROI_TAG, _ROI_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], - output_streams=[ - ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ], + output_streams=output_streams, task_options=options, ) return cls( @@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): image: image_module.Image, roi: RegionOfInterest, image_processing_options: Optional[_ImageProcessingOptions] = None, - ) -> List[image_module.Image]: + ) -> InteractiveSegmenterResult: """Performs the actual segmentation task on the provided MediaPipe Image. The image can be of any size with format RGB. @@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = InteractiveSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts index 879e23010..903d789f5 100644 --- a/mediapipe/tasks/web/vision/core/render_utils.ts +++ b/mediapipe/tasks/web/vision/core/render_utils.ts @@ -59,13 +59,12 @@ export function drawCategoryMask( const isFloatArray = image instanceof Float32Array; for (let i = 0; i < image.length; i++) { const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; - const color = COLOR_MAP[colorIndex]; + let color = COLOR_MAP[colorIndex % COLOR_MAP.length]; - // When we're given a confidence mask by accident, we just log and return. - // TODO: We should fix this. if (!color) { + // TODO: We should fix this. console.warn('No color for ', colorIndex); - return; + color = COLOR_MAP[colorIndex % COLOR_MAP.length]; } rgbaArray[4 * i] = color[0]; diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index 5699126b9..344d4db85 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke */ export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture; -/** - * A callback that receives the computed masks from the segmentation tasks. The - * callback either receives a single element array with a category mask (as a - * `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`). - * The returned data is only valid for the duration of the callback. If - * asynchronous processing is needed, all data needs to be copied before the - * callback returns. 
- */ -export type SegmentationMaskCallback = - (masks: SegmentationMask[], width: number, height: number) => void; - /** * A callback that receives an `ImageData` object from a Vision task. The * lifetime of the underlying data is limited to the duration of the callback. diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts index 978324750..337f663e3 100644 --- a/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts @@ -19,7 +19,7 @@ import {Connection} from '../../../../tasks/web/vision/core/types'; // tslint:disable:class-as-namespace Using for easier import by 3P users /** - * A class containing the Pairs of landmark indices to be rendered with + * A class containing the pairs of landmark indices to be rendered with * connections. */ export class FaceLandmarksConnections { diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts index 34067aaba..dfce03030 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts @@ -129,10 +129,6 @@ export class FaceStylizer extends VisionTaskRunner { * synchronously once the callback returns. Only use this method when the * FaceStylizer is created with the image running mode. * - * The input image can be of any size. To ensure that the output image has - * reasonable quality, the stylized output image size is determined by the - * model output size. - * * @param image An image to process. * @param callback The callback that is invoked with the stylized image. The * lifetime of the returned data is only guaranteed for the duration of the @@ -153,11 +149,6 @@ export class FaceStylizer extends VisionTaskRunner { * If both are specified, the crop around the region-of-interest is extracted * first, then the specified rotation is applied to the crop. * - * The input image can be of any size. To ensure that the output image has - * reasonable quality, the stylized output image size is the smaller of the - * model output size and the size of the 'regionOfInterest' specified in - * 'imageProcessingOptions'. - * * @param image An image to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. @@ -192,9 +183,6 @@ export class FaceStylizer extends VisionTaskRunner { * frame's timestamp (in milliseconds). The input timestamps must be * monotonically increasing. * - * To ensure that the output image has reasonable quality, the stylized - * output image size is determined by the model output size. - * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @param callback The callback that is invoked with the stylized image. The @@ -221,10 +209,6 @@ export class FaceStylizer extends VisionTaskRunner { * frame's timestamp (in milliseconds). The input timestamps must be * monotonically increasing. * - * To ensure that the output image has reasonable quality, the stylized - * output image size is the smaller of the model output size and the size of - * the 'regionOfInterest' specified in 'imageProcessingOptions'. - * * @param videoFrame A video frame to process. 
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. @@ -278,8 +262,12 @@ export class FaceStylizer extends VisionTaskRunner { this.graphRunner.attachImageListener( STYLIZED_IMAGE_STREAM, (image, timestamp) => { - const imageData = this.convertToImageData(image); - this.userCallback(imageData, image.width, image.height); + if (image.data instanceof WebGLTexture) { + this.userCallback(image.data, image.width, image.height); + } else { + const imageData = this.convertToImageData(image); + this.userCallback(imageData, image.width, image.height); + } this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 9156e89b7..a3a630e90 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -34,6 +34,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 74d37cb63..df9c91282 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -31,6 +31,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -72,6 +73,12 @@ export class GestureRecognizer extends VisionTaskRunner { private readonly handGestureRecognizerGraphOptions: HandGestureRecognizerGraphOptions; + /** + * An array containing the pairs of hand landmark indices to be rendered with + * connections. + */ + static HAND_CONNECTIONS = HAND_CONNECTIONS; + /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. 
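The hunk above makes the hand connection topology part of the public GestureRecognizer surface (HandLandmarker gains the same static property below). As a minimal sketch of how a caller might consume it, the snippet assumes the published `@mediapipe/tasks-vision` bundle, a 2D canvas context, and normalized landmarks such as those returned in `GestureRecognizerResult.landmarks`; the `drawHandSkeleton` helper is illustrative and not part of this change.

import {GestureRecognizer} from '@mediapipe/tasks-vision';

// Minimal stand-in for a normalized landmark (x/y in the [0, 1] range).
interface Point {
  x: number;
  y: number;
}

// Draws one hand skeleton by connecting the landmark index pairs exposed via
// the new static GestureRecognizer.HAND_CONNECTIONS constant.
function drawHandSkeleton(
    ctx: CanvasRenderingContext2D, landmarks: Point[]): void {
  ctx.strokeStyle = 'lime';
  ctx.lineWidth = 2;
  for (const {start, end} of GestureRecognizer.HAND_CONNECTIONS) {
    const from = landmarks[start];
    const to = landmarks[end];
    ctx.beginPath();
    ctx.moveTo(from.x * ctx.canvas.width, from.y * ctx.canvas.height);
    ctx.lineTo(to.x * ctx.canvas.width, to.y * ctx.canvas.height);
    ctx.stroke();
  }
}

Exposing the pairs as a static class member means renderers no longer need to hard-code the 21-landmark hand topology themselves.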
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index c5687ee2f..948d56927 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -16,6 +16,7 @@ mediapipe_ts_library( visibility = ["//visibility:public"], deps = [ ":hand_landmarker_types", + ":hand_landmarks_connections", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", @@ -72,3 +73,9 @@ jasmine_node_test( tags = ["nomsan"], deps = [":hand_landmarker_test_lib"], ) + +mediapipe_ts_library( + name = "hand_landmarks_connections", + srcs = ["hand_landmarks_connections.ts"], + deps = ["//mediapipe/tasks/web/vision/core:types"], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 1978bb061..62928536d 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -27,6 +27,7 @@ import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/con import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -63,6 +64,12 @@ export class HandLandmarker extends VisionTaskRunner { HandLandmarksDetectorGraphOptions; private readonly handDetectorGraphOptions: HandDetectorGraphOptions; + /** + * An array containing the pairs of hand landmark indices to be rendered with + * connections. + */ + static HAND_CONNECTIONS = HAND_CONNECTIONS; + /** * Initializes the Wasm runtime and creates a new `HandLandmarker` from the * provided options. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarks_connections.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarks_connections.ts new file mode 100644 index 000000000..edb789c8f --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarks_connections.ts @@ -0,0 +1,31 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Connection} from '../../../../tasks/web/vision/core/types'; + +/** + * An array containing the pairs of hand landmark indices to be rendered with + * connections. 
+ */ +export const HAND_CONNECTIONS: Connection[] = [ + {start: 0, end: 1}, {start: 1, end: 2}, {start: 2, end: 3}, + {start: 3, end: 4}, {start: 0, end: 5}, {start: 5, end: 6}, + {start: 6, end: 7}, {start: 7, end: 8}, {start: 5, end: 9}, + {start: 9, end: 10}, {start: 10, end: 11}, {start: 11, end: 12}, + {start: 9, end: 13}, {start: 13, end: 14}, {start: 14, end: 15}, + {start: 15, end: 16}, {start: 13, end: 17}, {start: 0, end: 17}, + {start: 17, end: 18}, {start: 18, end: 19}, {start: 19, end: 20} +]; diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD index a4b9008dd..3db15641f 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -29,7 +29,10 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "image_segmenter_types", - srcs = ["image_segmenter_options.d.ts"], + srcs = [ + "image_segmenter_options.d.ts", + "image_segmenter_result.d.ts", + ], deps = [ "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index 3690fd855..740047762 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; -import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {SegmentationMask} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {LabelMapItem} from '../../../../util/label_map_pb'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageSegmenterOptions} from './image_segmenter_options'; +import {ImageSegmenterResult} from './image_segmenter_result'; export * from './image_segmenter_options'; -export {SegmentationMask, SegmentationMaskCallback}; +export * from './image_segmenter_result'; +export {SegmentationMask}; export {ImageSource}; // Used in the public API const IMAGE_STREAM = 'image_in'; const NORM_RECT_STREAM = 'norm_rect'; -const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; +const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; +const CATEGORY_MASK_STREAM = 'category_mask'; const IMAGE_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = 'mediapipe.tasks.TensorsToSegmentationCalculator'; +const DEFAULT_OUTPUT_CATEGORY_MASK = false; +const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern +/** + * A callback that receives the computed masks from the image segmenter. The + * returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. 
+ */ +export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void; + /** Performs image segmentation on images. */ export class ImageSegmenter extends VisionTaskRunner { - private userCallback: SegmentationMaskCallback = () => {}; + private result: ImageSegmenterResult = {width: 0, height: 0}; private labels: string[] = []; + private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; + private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner { this.options.setBaseOptions(new BaseOptionsProto()); } - protected override get baseOptions(): BaseOptionsProto { return this.options.getBaseOptions()!; } @@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner { this.options.clearDisplayNamesLocale(); } - if (options.outputType === 'CONFIDENCE_MASK') { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); - } else { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CATEGORY_MASK); + if ('outputCategoryMask' in options) { + this.outputCategoryMask = + options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK; + } + + if ('outputConfidenceMasks' in options) { + this.outputConfidenceMasks = + options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS; } return super.applyOptions(options); @@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner { * lifetime of the returned data is only guaranteed for the duration of the * callback. */ - segment(image: ImageSource, callback: SegmentationMaskCallback): void; + segment(image: ImageSource, callback: ImageSegmenterCallack): void; /** * Performs image segmentation on the provided single image and invokes the * callback with the response. The method returns synchronously once the @@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner { */ segment( image: ImageSource, imageProcessingOptions: ImageProcessingOptions, - callback: SegmentationMaskCallback): void; + callback: ImageSegmenterCallack): void; segment( image: ImageSource, imageProcessingOptionsOrCallback: ImageProcessingOptions| - SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { + ImageSegmenterCallack, + callback?: ImageSegmenterCallack): void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; - - this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + const userCallback = + typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : callback!; + + this.reset(); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + userCallback(this.result); + } + + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. 
+ */ + segmentForVideo( + videoFrame: ImageSource, timestamp: number, + callback: ImageSegmenterCallack): void; + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segmentForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number, callback: ImageSegmenterCallack): void; + segmentForVideo( + videoFrame: ImageSource, + timestampOrImageProcessingOptions: number|ImageProcessingOptions, + timestampOrCallback: number|ImageSegmenterCallack, + callback?: ImageSegmenterCallack): void { + const imageProcessingOptions = + typeof timestampOrImageProcessingOptions !== 'number' ? + timestampOrImageProcessingOptions : + {}; + const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? + timestampOrImageProcessingOptions : + timestampOrCallback as number; + const userCallback = typeof timestampOrCallback === 'function' ? + timestampOrCallback : + callback!; + + this.reset(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + userCallback(this.result); } /** @@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner { return this.labels; } - /** - * Performs image segmentation on the provided video frame and invokes the - * callback with the response. The method returns synchronously once the - * callback returns. Only use this method when the ImageSegmenter is - * created with running mode `video`. - * - * @param videoFrame A video frame to process. - * @param timestamp The timestamp of the current frame, in ms. - * @param callback The callback that is invoked with the segmented masks. The - * lifetime of the returned data is only guaranteed for the duration of the - * callback. - */ - segmentForVideo( - videoFrame: ImageSource, timestamp: number, - callback: SegmentationMaskCallback): void; - /** - * Performs image segmentation on the provided video frame and invokes the - * callback with the response. The method returns synchronously once the - * callback returns. Only use this method when the ImageSegmenter is - * created with running mode `video`. - * - * @param videoFrame A video frame to process. - * @param imageProcessingOptions the `ImageProcessingOptions` specifying how - * to process the input image before running inference. - * @param timestamp The timestamp of the current frame, in ms. - * @param callback The callback that is invoked with the segmented masks. The - * lifetime of the returned data is only guaranteed for the duration of the - * callback. 
- */ - segmentForVideo( - videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, - timestamp: number, callback: SegmentationMaskCallback): void; - segmentForVideo( - videoFrame: ImageSource, - timestampOrImageProcessingOptions: number|ImageProcessingOptions, - timestampOrCallback: number|SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { - const imageProcessingOptions = - typeof timestampOrImageProcessingOptions !== 'number' ? - timestampOrImageProcessingOptions : - {}; - const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? - timestampOrImageProcessingOptions : - timestampOrCallback as number; - - this.userCallback = typeof timestampOrCallback === 'function' ? - timestampOrCallback : - callback!; - this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - this.userCallback = () => {}; + private reset(): void { + this.result = {width: 0, height: 0}; } /** Updates the MediaPipe graph configuration. */ @@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); - graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner { segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH); segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); - segmenterNode.addOutputStream( - 'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM); segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); - this.graphRunner.attachImageVectorListener( - GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => { - if (masks.length === 0) { - this.userCallback([], 0, 0); - } else { - this.userCallback( - masks.map(m => m.data), masks[0].width, masks[0].height); - } - this.setLatestOutputTimestamp(timestamp); - }); - this.graphRunner.attachEmptyPacketListener( - GROUPED_SEGMENTATIONS_STREAM, timestamp => { - this.setLatestOutputTimestamp(timestamp); - }); + if (this.outputConfidenceMasks) { + graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); + segmenterNode.addOutputStream( + 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + + this.graphRunner.attachImageVectorListener( + CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { + this.result.confidenceMasks = masks.map(m => m.data); + if (masks.length >= 0) { + this.result.width = masks[0].width; + this.result.height = masks[0].height; + } + + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CONFIDENCE_MASKS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + if (this.outputCategoryMask) { + graphConfig.addOutputStream(CATEGORY_MASK_STREAM); + segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + + this.graphRunner.attachImageListener( + CATEGORY_MASK_STREAM, (mask, timestamp) => { + this.result.categoryMask = mask.data; + this.result.width = mask.width; + this.result.height = mask.height; + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CATEGORY_MASK_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git 
a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts index c17e7e421..f80a792a5 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts @@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions { */ displayNamesLocale?: string|undefined; - /** - * The output type of segmentation results. - * - * The two supported modes are: - * - Category Mask: Gives a single output mask where each pixel represents - * the class which the pixel in the original image was - * predicted to belong to. - * - Confidence Mask: Gives a list of output masks (one for each class). For - * each mask, the pixel represents the prediction - * confidence, usually in the [0.0, 0.1] range. - * - * Defaults to `CATEGORY_MASK`. - */ - outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; + /** Whether to output confidence masks. Defaults to true. */ + outputConfidenceMasks?: boolean|undefined; + + /** Whether to output the category masks. Defaults to false. */ + outputCategoryMask?: boolean|undefined; } diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts new file mode 100644 index 000000000..be082d516 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The output result of ImageSegmenter. */ +export declare interface ImageSegmenterResult { + /** + * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each + * pixel represents the prediction confidence, usually in the [0, 1] range. + */ + confidenceMasks?: Float32Array[]|WebGLTexture[]; + + /** + * A category mask as a Uint8ClampedArray or WebGLTexture where each + * pixel represents the class which the pixel in the original image was + * predicted to belong to. + */ + categoryMask?: Uint8ClampedArray|WebGLTexture; + + /** The width of the masks. */ + width: number; + + /** The height of the masks. 
*/ + height: number; +} diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts index 4cf27b9a5..6b5c90080 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -18,7 +18,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; -import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {ImageSegmenter} from './image_segmenter'; @@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - imageVectorListener: + categoryMaskListener: + ((images: WasmImage, timestamp: number) => void)|undefined; + confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; constructor() { @@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { this.fakeWasmModule = this.graphRunner.wasmModule as unknown as SpyWasmModule; - this.attachListenerSpies[0] = + this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('category_mask'); + this.categoryMaskListener = listener; + }); + this.attachListenerSpies[1] = spyOn(this.graphRunner, 'attachImageVectorListener') .and.callFake((stream, listener) => { - expect(stream).toEqual('segmented_masks'); - this.imageVectorListener = listener; + expect(stream).toEqual('confidence_masks'); + this.confidenceMasksListener = listener; }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -63,17 +70,18 @@ describe('ImageSegmenter', () => { it('initializes graph', async () => { verifyGraph(imageSegmenter); - verifyListenersRegistered(imageSegmenter); + + // Verify default options + expect(imageSegmenter.categoryMaskListener).not.toBeDefined(); + expect(imageSegmenter.confidenceMasksListener).toBeDefined(); }); it('reloads graph when settings are changed', async () => { await imageSegmenter.setOptions({displayNamesLocale: 'en'}); verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); - verifyListenersRegistered(imageSegmenter); await imageSegmenter.setOptions({displayNamesLocale: 'de'}); verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']); - verifyListenersRegistered(imageSegmenter); }); it('can use custom models', async () => { @@ -100,9 +108,11 @@ describe('ImageSegmenter', () => { }); it('merges options', async () => { - await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + await imageSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); await imageSegmenter.setOptions({displayNamesLocale: 'en'}); - verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyGraph( + imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']); verifyGraph(imageSegmenter, ['displayNamesLocale', 
'en']); }); @@ -115,22 +125,13 @@ describe('ImageSegmenter', () => { defaultValue: unknown; } - const testCases: TestCase[] = [ - { - optionName: 'displayNamesLocale', - fieldPath: ['displayNamesLocale'], - userValue: 'en', - graphValue: 'en', - defaultValue: 'en' - }, - { - optionName: 'outputType', - fieldPath: ['segmenterOptions', 'outputType'], - userValue: 'CONFIDENCE_MASK', - graphValue: 2, - defaultValue: 1 - }, - ]; + const testCases: TestCase[] = [{ + optionName: 'displayNamesLocale', + fieldPath: ['displayNamesLocale'], + userValue: 'en', + graphValue: 'en', + defaultValue: 'en' + }]; for (const testCase of testCases) { it(`can set ${testCase.optionName}`, async () => { @@ -158,27 +159,31 @@ describe('ImageSegmenter', () => { }).toThrowError('This task doesn\'t support region-of-interest.'); }); - it('supports category masks', (done) => { + it('supports category mask', async () => { const mask = new Uint8ClampedArray([1, 2, 3, 4]); + await imageSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: false}); + // Pass the test data to our listener imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(imageSegmenter); - imageSegmenter.imageVectorListener!( - [ - {data: mask, width: 2, height: 2}, - ], - /* timestamp= */ 1337); + expect(imageSegmenter.categoryMaskListener).toBeDefined(); + imageSegmenter.categoryMaskListener! + ({data: mask, width: 2, height: 2}, + /* timestamp= */ 1337); }); // Invoke the image segmenter - imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { - expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(masks).toHaveSize(1); - expect(masks[0]).toEqual(mask); - expect(width).toEqual(2); - expect(height).toEqual(2); - done(); + + return new Promise(resolve => { + imageSegmenter.segment({} as HTMLImageElement, result => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result.categoryMask).toEqual(mask); + expect(result.confidenceMasks).not.toBeDefined(); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); }); }); @@ -186,12 +191,13 @@ describe('ImageSegmenter', () => { const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); - await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + await imageSegmenter.setOptions( + {outputCategoryMask: false, outputConfidenceMasks: true}); // Pass the test data to our listener imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(imageSegmenter); - imageSegmenter.imageVectorListener!( + expect(imageSegmenter.confidenceMasksListener).toBeDefined(); + imageSegmenter.confidenceMasksListener!( [ {data: mask1, width: 2, height: 2}, {data: mask2, width: 2, height: 2}, @@ -201,13 +207,49 @@ describe('ImageSegmenter', () => { return new Promise(resolve => { // Invoke the image segmenter - imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + imageSegmenter.segment({} as HTMLImageElement, result => { expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(masks).toHaveSize(2); - expect(masks[0]).toEqual(mask1); - expect(masks[1]).toEqual(mask2); - expect(width).toEqual(2); - expect(height).toEqual(2); + expect(result.categoryMask).not.toBeDefined(); + expect(result.confidenceMasks).toEqual([mask1, mask2]); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); + }); + 
}); + + it('supports combined category and confidence masks', async () => { + const categoryMask = new Uint8ClampedArray([1, 0]); + const confidenceMask1 = new Float32Array([0.0, 1.0]); + const confidenceMask2 = new Float32Array([1.0, 0.0]); + + await imageSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: true}); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(imageSegmenter.categoryMaskListener).toBeDefined(); + expect(imageSegmenter.confidenceMasksListener).toBeDefined(); + imageSegmenter.categoryMaskListener! + ({data: categoryMask, width: 1, height: 1}, 1337); + imageSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask1, width: 1, height: 1}, + {data: confidenceMask2, width: 1, height: 1}, + ], + 1337); + }); + + return new Promise(resolve => { + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, result => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result.categoryMask).toEqual(categoryMask); + expect(result.confidenceMasks).toEqual([ + confidenceMask1, confidenceMask2 + ]); + expect(result.width).toEqual(1); + expect(result.height).toEqual(1); resolve(); }); }); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD index a4a3f27c9..ead85d38a 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD @@ -30,7 +30,10 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "interactive_segmenter_types", - srcs = ["interactive_segmenter_options.d.ts"], + srcs = [ + "interactive_segmenter_options.d.ts", + "interactive_segmenter_result.d.ts", + ], deps = [ "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index ddcc7e592..16841bd7f 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; -import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {Color as ColorProto} from '../../../../util/color_pb'; import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; @@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner // Placeholder for internal dependency on trusted resource url import {InteractiveSegmenterOptions} from './interactive_segmenter_options'; +import {InteractiveSegmenterResult} from './interactive_segmenter_result'; export * from './interactive_segmenter_options'; -export {SegmentationMask, SegmentationMaskCallback, 
RegionOfInterest}; +export * from './interactive_segmenter_result'; +export {SegmentationMask, RegionOfInterest}; export {ImageSource}; const IMAGE_IN_STREAM = 'image_in'; const NORM_RECT_IN_STREAM = 'norm_rect_in'; const ROI_IN_STREAM = 'roi_in'; -const IMAGE_OUT_STREAM = 'image_out'; +const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; +const CATEGORY_MASK_STREAM = 'category_mask'; const IMAGEA_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; +const DEFAULT_OUTPUT_CATEGORY_MASK = false; +const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern +/** + * A callback that receives the computed masks from the interactive segmenter. + * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. + */ +export type InteractiveSegmenterCallack = + (result: InteractiveSegmenterResult) => void; + /** * Performs interactive segmentation on images. * @@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH = * - batch is always 1 */ export class InteractiveSegmenter extends VisionTaskRunner { - private userCallback: SegmentationMaskCallback = () => {}; + private result: InteractiveSegmenterResult = {width: 0, height: 0}; + private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; + private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner { * @return A Promise that resolves when the settings have been applied. */ override setOptions(options: InteractiveSegmenterOptions): Promise { - if (options.outputType === 'CONFIDENCE_MASK') { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); - } else { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CATEGORY_MASK); + if ('outputCategoryMask' in options) { + this.outputCategoryMask = + options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK; + } + + if ('outputConfidenceMasks' in options) { + this.outputConfidenceMasks = + options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS; } return super.applyOptions(options); @@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { */ segment( image: ImageSource, roi: RegionOfInterest, - callback: SegmentationMaskCallback): void; + callback: InteractiveSegmenterCallack): void; /** * Performs interactive segmentation on the provided single image and invokes * the callback with the response. The `roi` parameter is used to represent a @@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner { segment( image: ImageSource, roi: RegionOfInterest, imageProcessingOptions: ImageProcessingOptions, - callback: SegmentationMaskCallback): void; + callback: InteractiveSegmenterCallack): void; segment( image: ImageSource, roi: RegionOfInterest, imageProcessingOptionsOrCallback: ImageProcessingOptions| - SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { + InteractiveSegmenterCallack, + callback?: InteractiveSegmenterCallack): void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; - - this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? 
+ const userCallback = + typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : callback!; + this.reset(); this.processRenderData(roi, this.getSynctheticTimestamp()); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + userCallback(this.result); + } + + private reset(): void { + this.result = {width: 0, height: 0}; } /** Updates the MediaPipe graph configuration. */ @@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner { graphConfig.addInputStream(IMAGE_IN_STREAM); graphConfig.addInputStream(ROI_IN_STREAM); graphConfig.addInputStream(NORM_RECT_IN_STREAM); - graphConfig.addOutputStream(IMAGE_OUT_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner { segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM); segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM); - segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM); segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); - this.graphRunner.attachImageVectorListener( - IMAGE_OUT_STREAM, (masks, timestamp) => { - if (masks.length === 0) { - this.userCallback([], 0, 0); - } else { - this.userCallback( - masks.map(m => m.data), masks[0].width, masks[0].height); - } - this.setLatestOutputTimestamp(timestamp); - }); - this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => { - this.setLatestOutputTimestamp(timestamp); - }); + if (this.outputConfidenceMasks) { + graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); + segmenterNode.addOutputStream( + 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + + this.graphRunner.attachImageVectorListener( + CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { + this.result.confidenceMasks = masks.map(m => m.data); + if (masks.length >= 0) { + this.result.width = masks[0].width; + this.result.height = masks[0].height; + } + + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CONFIDENCE_MASKS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + if (this.outputCategoryMask) { + graphConfig.addOutputStream(CATEGORY_MASK_STREAM); + segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + + this.graphRunner.attachImageListener( + CATEGORY_MASK_STREAM, (mask, timestamp) => { + this.result.categoryMask = mask.data; + this.result.width = mask.width; + this.result.height = mask.height; + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CATEGORY_MASK_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts index beb43cd81..269403d97 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts @@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options' /** Options to configure the MediaPipe Interactive Segmenter Task */ export interface InteractiveSegmenterOptions extends 
TaskRunnerOptions { - /** - * The output type of segmentation results. - * - * The two supported modes are: - * - Category Mask: Gives a single output mask where each pixel represents - * the class which the pixel in the original image was - * predicted to belong to. - * - Confidence Mask: Gives a list of output masks (one for each class). For - * each mask, the pixel represents the prediction - * confidence, usually in the [0.0, 0.1] range. - * - * Defaults to `CATEGORY_MASK`. - */ - outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; + /** Whether to output confidence masks. Defaults to true. */ + outputConfidenceMasks?: boolean|undefined; + + /** Whether to output the category masks. Defaults to false. */ + outputCategoryMask?: boolean|undefined; } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts new file mode 100644 index 000000000..f7e1f3a19 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The output result of InteractiveSegmenter. */ +export declare interface InteractiveSegmenterResult { + /** + * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each + * pixel represents the prediction confidence, usually in the [0, 1] range. + */ + confidenceMasks?: Float32Array[]|WebGLTexture[]; + + /** + * A category mask as a Uint8ClampedArray or WebGLTexture where each + * pixel represents the class which the pixel in the original image was + * predicted to belong to. + */ + categoryMask?: Uint8ClampedArray|WebGLTexture; + + /** The width of the masks. */ + width: number; + + /** The height of the masks. 
*/ + height: number; +} diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index d6e3a97a5..884be032d 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -18,7 +18,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; -import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; @@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - imageVectorListener: + categoryMaskListener: + ((images: WasmImage, timestamp: number) => void)|undefined; + confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; lastRoi?: RenderDataProto; @@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements this.fakeWasmModule = this.graphRunner.wasmModule as unknown as SpyWasmModule; - this.attachListenerSpies[0] = + this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('category_mask'); + this.categoryMaskListener = listener; + }); + this.attachListenerSpies[1] = spyOn(this.graphRunner, 'attachImageVectorListener') .and.callFake((stream, listener) => { - expect(stream).toEqual('image_out'); - this.imageVectorListener = listener; + expect(stream).toEqual('confidence_masks'); + this.confidenceMasksListener = listener; }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => { it('initializes graph', async () => { verifyGraph(interactiveSegmenter); - verifyListenersRegistered(interactiveSegmenter); + + // Verify default options + expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); }); it('reloads graph when settings are changed', async () => { - await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); - verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]); - verifyListenersRegistered(interactiveSegmenter); + await interactiveSegmenter.setOptions( + {outputConfidenceMasks: true, outputCategoryMask: false}); + expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]); - verifyListenersRegistered(interactiveSegmenter); + await interactiveSegmenter.setOptions( + {outputConfidenceMasks: false, outputCategoryMask: true}); + 
expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); }); it('can use custom models', async () => { @@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => { ]); }); - - describe('setOptions()', () => { - const fieldPath = ['segmenterOptions', 'outputType']; - - it(`can set outputType`, async () => { - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [fieldPath, 2]); - }); - - it(`can clear outputType`, async () => { - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [fieldPath, 2]); - await interactiveSegmenter.setOptions({outputType: undefined}); - verifyGraph(interactiveSegmenter, [fieldPath, 1]); - }); - }); - it('doesn\'t support region of interest', () => { expect(() => { interactiveSegmenter.segment( @@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => { interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); }); - it('supports category masks', (done) => { + it('supports category mask', async () => { const mask = new Uint8ClampedArray([1, 2, 3, 4]); + await interactiveSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: false}); + // Pass the test data to our listener interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(interactiveSegmenter); - interactiveSegmenter.imageVectorListener!( - [ - {data: mask, width: 2, height: 2}, - ], - /* timestamp= */ 1337); + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); + interactiveSegmenter.categoryMaskListener! + ({data: mask, width: 2, height: 2}, + /* timestamp= */ 1337); }); // Invoke the image segmenter - interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, (masks, width, height) => { - expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) - .toHaveBeenCalled(); - expect(masks).toHaveSize(1); - expect(masks[0]).toEqual(mask); - expect(width).toEqual(2); - expect(height).toEqual(2); - done(); - }); + return new Promise(resolve => { + interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(result.categoryMask).toEqual(mask); + expect(result.confidenceMasks).not.toBeDefined(); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); + }); }); it('supports confidence masks', async () => { const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + await interactiveSegmenter.setOptions( + {outputCategoryMask: false, outputConfidenceMasks: true}); // Pass the test data to our listener interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(interactiveSegmenter); - interactiveSegmenter.imageVectorListener!( + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); + interactiveSegmenter.confidenceMasksListener!( [ {data: mask1, width: 2, height: 2}, {data: mask2, width: 2, height: 2}, ], 1337); }); + return new Promise(resolve => { + // Invoke the image segmenter + interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(result.categoryMask).not.toBeDefined(); + expect(result.confidenceMasks).toEqual([mask1, mask2]); + expect(result.width).toEqual(2); + 
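For intuition on how the two outputs relate: the category mask is conceptually the per-pixel argmax over the per-class confidence masks. A small illustrative helper under that assumption (not MediaPipe code; it only handles CPU-backed Float32Array masks):

```ts
// Illustrative helper, not MediaPipe code: derive a category mask from
// CPU-backed confidence masks by taking the per-pixel argmax. Assumes all
// masks are Float32Arrays of the same width * height.
function categoryMaskFromConfidenceMasks(
    confidenceMasks: Float32Array[]): Uint8ClampedArray {
  const numPixels = confidenceMasks[0].length;
  const categoryMask = new Uint8ClampedArray(numPixels);
  for (let pixel = 0; pixel < numPixels; ++pixel) {
    let bestClass = 0;
    let bestConfidence = confidenceMasks[0][pixel];
    for (let cls = 1; cls < confidenceMasks.length; ++cls) {
      if (confidenceMasks[cls][pixel] > bestConfidence) {
        bestConfidence = confidenceMasks[cls][pixel];
        bestClass = cls;
      }
    }
    categoryMask[pixel] = bestClass;
  }
  return categoryMask;
}

// With the combined-output test data below, masks [0.0, 1.0] and [1.0, 0.0]
// yield the category mask [1, 0], matching the test's expected categoryMask.
```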
expect(result.height).toEqual(2); + resolve(); + }); + }); + }); + + it('supports combined category and confidence masks', async () => { + const categoryMask = new Uint8ClampedArray([1, 0]); + const confidenceMask1 = new Float32Array([0.0, 1.0]); + const confidenceMask2 = new Float32Array([1.0, 0.0]); + + await interactiveSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: true}); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); + interactiveSegmenter.categoryMaskListener! + ({data: categoryMask, width: 1, height: 1}, 1337); + interactiveSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask1, width: 1, height: 1}, + {data: confidenceMask2, width: 1, height: 1}, + ], + 1337); + }); return new Promise(resolve => { // Invoke the image segmenter interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, (masks, width, height) => { + {} as HTMLImageElement, ROI, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); - expect(masks).toHaveSize(2); - expect(masks[0]).toEqual(mask1); - expect(masks[1]).toEqual(mask2); - expect(width).toEqual(2); - expect(height).toEqual(2); + expect(result.categoryMask).toEqual(categoryMask); + expect(result.confidenceMasks).toEqual([ + confidenceMask1, confidenceMask2 + ]); + expect(result.width).toEqual(1); + expect(result.height).toEqual(1); resolve(); }); }); diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 671f47505..5188da896 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -56,8 +56,8 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y, VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0"; } - *x_px = static_cast(round(normalized_x * image_width)); - *y_px = static_cast(round(normalized_y * image_height)); + *x_px = static_cast(round(normalized_x * image_width)); + *y_px = static_cast(round(normalized_y * image_height)); return true; } diff --git a/mediapipe/util/cpu_util.cc b/mediapipe/util/cpu_util.cc index c1be9793b..052eabb85 100644 --- a/mediapipe/util/cpu_util.cc +++ b/mediapipe/util/cpu_util.cc @@ -43,7 +43,7 @@ ABSL_FLAG(std::string, system_cpu_max_freq_file, namespace mediapipe { namespace { -constexpr uint32 kBufferLength = 64; +constexpr uint32_t kBufferLength = 64; absl::StatusOr GetFilePath(int cpu) { if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) { @@ -54,7 +54,7 @@ absl::StatusOr GetFilePath(int cpu) { return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu); } -absl::StatusOr GetCpuMaxFrequency(int cpu) { +absl::StatusOr GetCpuMaxFrequency(int cpu) { auto path_or_status = GetFilePath(cpu); if (!path_or_status.ok()) { return path_or_status.status(); @@ -65,7 +65,7 @@ absl::StatusOr GetCpuMaxFrequency(int cpu) { char buffer[kBufferLength]; file.getline(buffer, kBufferLength); file.close(); - uint64 frequency; + uint64_t frequency; if (absl::SimpleAtoi(buffer, &frequency)) { return frequency; } else { @@ -79,7 +79,7 @@ absl::StatusOr GetCpuMaxFrequency(int cpu) { } std::set InferLowerOrHigherCoreIds(bool lower) { - std::vector> cpu_freq_pairs; + std::vector> cpu_freq_pairs; for (int cpu = 0; cpu < NumCPUCores(); ++cpu) { auto freq_or_status = GetCpuMaxFrequency(cpu); if (freq_or_status.ok()) { 
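The cpu_util.cc hunks here and in the next hunk are a mechanical uint64 to uint64_t migration, but the surrounding logic is worth spelling out: read each core's maximum frequency from sysfs, sort the (core, frequency) pairs, and keep the cores that share the lowest or highest value. A re-expression of that idea, for illustration only (Node-style TypeScript; the sysfs path template mirrors the apparent default of the system_cpu_max_freq_file flag and is an assumption):

```ts
// Illustrative only (not MediaPipe code): the idea behind
// InferLowerOrHigherCoreIds, re-expressed for Node. Reads each core's
// maximum frequency from sysfs, then returns the cores that share the
// lowest (or highest) frequency.
import * as fs from 'fs';
import * as os from 'os';

// Assumed path template; "$0" is substituted with the core index.
const MAX_FREQ_FILE = '/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq';

function inferLowerOrHigherCoreIds(lower: boolean): Set<number> {
  const cpuFreqPairs: Array<[number, number]> = [];
  for (let cpu = 0; cpu < os.cpus().length; ++cpu) {
    try {
      const path = MAX_FREQ_FILE.replace('$0', String(cpu));
      const frequency = parseInt(fs.readFileSync(path, 'utf8').trim(), 10);
      if (!Number.isNaN(frequency)) cpuFreqPairs.push([cpu, frequency]);
    } catch {
      // A core without a readable max-frequency file is skipped, as in the C++.
    }
  }
  if (cpuFreqPairs.length === 0) return new Set();
  cpuFreqPairs.sort((a, b) => lower ? a[1] - b[1] : b[1] - a[1]);
  const edgeFreq = cpuFreqPairs[0][1];
  return new Set(
      cpuFreqPairs.filter(([, freq]) => freq === edgeFreq).map(([cpu]) => cpu));
}
```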
@@ -90,12 +90,12 @@ std::set InferLowerOrHigherCoreIds(bool lower) { return {}; } - absl::c_sort(cpu_freq_pairs, [lower](const std::pair& left, - const std::pair& right) { + absl::c_sort(cpu_freq_pairs, [lower](const std::pair& left, + const std::pair& right) { return (lower && left.second < right.second) || (!lower && left.second > right.second); }); - uint64 edge_freq = cpu_freq_pairs[0].second; + uint64_t edge_freq = cpu_freq_pairs[0].second; std::set inferred_cores; for (const auto& cpu_freq_pair : cpu_freq_pairs) { diff --git a/mediapipe/util/image_frame_util.cc b/mediapipe/util/image_frame_util.cc index a3a038b00..bf2773fdc 100644 --- a/mediapipe/util/image_frame_util.cc +++ b/mediapipe/util/image_frame_util.cc @@ -89,12 +89,12 @@ void ImageFrameToYUVImage(const ImageFrame& image_frame, YUVImage* yuv_image) { const int uv_stride = (uv_width + 15) & ~15; const int y_size = y_stride * height; const int uv_size = uv_stride * uv_height; - uint8* data = - reinterpret_cast(aligned_malloc(y_size + uv_size * 2, 16)); + uint8_t* data = + reinterpret_cast(aligned_malloc(y_size + uv_size * 2, 16)); std::function deallocate = [data]() { aligned_free(data); }; - uint8* y = data; - uint8* u = y + y_size; - uint8* v = u + uv_size; + uint8_t* y = data; + uint8_t* u = y + y_size; + uint8_t* v = u + uv_size; yuv_image->Initialize(libyuv::FOURCC_I420, deallocate, // y, y_stride, // u, uv_stride, // @@ -123,10 +123,11 @@ void ImageFrameToYUVNV12Image(const ImageFrame& image_frame, const int uv_stride = y_stride; const int uv_height = (height + 1) / 2; const int uv_size = uv_stride * uv_height; - uint8* data = reinterpret_cast(aligned_malloc(y_size + uv_size, 16)); + uint8_t* data = + reinterpret_cast(aligned_malloc(y_size + uv_size, 16)); std::function deallocate = [data] { aligned_free(data); }; - uint8* y = data; - uint8* uv = y + y_size; + uint8_t* y = data; + uint8_t* uv = y + y_size; yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv, uv_stride, nullptr, 0, width, height); const int rv = libyuv::I420ToNV12( @@ -210,44 +211,44 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image, } } -void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b, // - uint8* y, uint8* cb, uint8* cr) { +void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b, // + uint8_t* y, uint8_t* cb, uint8_t* cr) { // ITU-R BT.601 conversion from sRGB to YCbCr. // FastIntRound is used rather than SafeRound since the possible // range of values is [16,235] for Y and [16,240] for Cb and Cr and we // don't care about the rounding direction for values exactly between // two integers. - *y = static_cast( + *y = static_cast( mediapipe::MathUtil::FastIntRound(16.0 + // 65.481 * r / 255.0 + // 128.553 * g / 255.0 + // 24.966 * b / 255.0)); - *cb = static_cast( + *cb = static_cast( mediapipe::MathUtil::FastIntRound(128.0 + // -37.797 * r / 255.0 + // -74.203 * g / 255.0 + // 112.0 * b / 255.0)); - *cr = static_cast( + *cr = static_cast( mediapipe::MathUtil::FastIntRound(128.0 + // 112.0 * r / 255.0 + // -93.786 * g / 255.0 + // -18.214 * b / 255.0)); } -void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, // - uint8* r, uint8* g, uint8* b) { +void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr, // + uint8_t* r, uint8_t* g, uint8_t* b) { // ITU-R BT.601 conversion from YCbCr to sRGB // Use SafeRound since many MPEG YCbCr values do not correspond directly // to an sRGB value. 
- *r = mediapipe::MathUtil::SafeRound( // - 255.0 / 219.0 * (y - 16.0) + // + *r = mediapipe::MathUtil::SafeRound( // + 255.0 / 219.0 * (y - 16.0) + // 255.0 / 112.0 * 0.701 * (cr - 128.0)); - *g = mediapipe::MathUtil::SafeRound( + *g = mediapipe::MathUtil::SafeRound( 255.0 / 219.0 * (y - 16.0) - // 255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) - // 255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0)); - *b = mediapipe::MathUtil::SafeRound( // - 255.0 / 219.0 * (y - 16.0) + // + *b = mediapipe::MathUtil::SafeRound( // + 255.0 / 219.0 * (y - 16.0) + // 255.0 / 112.0 * 0.886 * (cb - 128.0)); } @@ -260,15 +261,15 @@ void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, // cv::Mat GetSrgbToLinearRgb16Lut() { cv::Mat lut(1, 256, CV_16UC1); - uint16* ptr = lut.ptr(); + uint16_t* ptr = lut.ptr(); constexpr double kUint8Max = 255.0; constexpr double kUint16Max = 65535.0; for (int i = 0; i < 256; ++i) { if (i < 0.04045 * kUint8Max) { - ptr[i] = static_cast( + ptr[i] = static_cast( (static_cast(i) / kUint8Max / 12.92) * kUint16Max + .5); } else { - ptr[i] = static_cast( + ptr[i] = static_cast( pow((static_cast(i) / kUint8Max + 0.055) / 1.055, 2.4) * kUint16Max + .5); @@ -279,15 +280,15 @@ cv::Mat GetSrgbToLinearRgb16Lut() { cv::Mat GetLinearRgb16ToSrgbLut() { cv::Mat lut(1, 65536, CV_8UC1); - uint8* ptr = lut.ptr(); + uint8_t* ptr = lut.ptr(); constexpr double kUint8Max = 255.0; constexpr double kUint16Max = 65535.0; for (int i = 0; i < 65536; ++i) { if (i < 0.0031308 * kUint16Max) { - ptr[i] = static_cast( + ptr[i] = static_cast( (static_cast(i) / kUint16Max * 12.92) * kUint8Max + .5); } else { - ptr[i] = static_cast( + ptr[i] = static_cast( (1.055 * pow(static_cast(i) / kUint16Max, 1.0 / 2.4) - .055) * kUint8Max + .5); @@ -306,13 +307,13 @@ void LinearRgb16ToSrgb(const cv::Mat& source, cv::Mat* destination) { destination->create(source.size(), CV_8UC(source.channels())); static const cv::Mat kLut = GetLinearRgb16ToSrgbLut(); - const uint8* lookup_table_ptr = kLut.ptr(); + const uint8_t* lookup_table_ptr = kLut.ptr(); const int num_channels = source.channels(); for (int row = 0; row < source.rows; ++row) { for (int col = 0; col < source.cols; ++col) { for (int channel = 0; channel < num_channels; ++channel) { - uint8* ptr = destination->ptr(row); - const uint16* ptr16 = source.ptr(row); + uint8_t* ptr = destination->ptr(row); + const uint16_t* ptr16 = source.ptr(row); ptr[col * num_channels + channel] = lookup_table_ptr[ptr16[col * num_channels + channel]]; } diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc index 815666985..77b755953 100644 --- a/mediapipe/util/image_test_utils.cc +++ b/mediapipe/util/image_test_utils.cc @@ -43,14 +43,14 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { Packet MakeImageFramePacket(cv::Mat input, int timestamp) { ImageFrame input_image(GetImageFormat(input.channels()), input.cols, - input.rows, input.step, input.data, [](uint8*) {}); + input.rows, input.step, input.data, [](uint8_t*) {}); return MakePacket(std::move(input_image)).At(Timestamp(0)); } Packet MakeImagePacket(cv::Mat input, int timestamp) { mediapipe::Image input_image(std::make_shared( GetImageFormat(input.channels()), input.cols, input.rows, input.step, - input.data, [](uint8*) {})); + input.data, [](uint8_t*) {})); return MakePacket(std::move(input_image)).At(Timestamp(0)); } diff --git a/mediapipe/util/label_map_util.cc b/mediapipe/util/label_map_util.cc index 914a2ba76..eb909349d 100644 --- 
a/mediapipe/util/label_map_util.cc +++ b/mediapipe/util/label_map_util.cc @@ -25,7 +25,7 @@ namespace mediapipe { -absl::StatusOr> BuildLabelMapFromFiles( +absl::StatusOr> BuildLabelMapFromFiles( absl::string_view labels_file_contents, absl::string_view display_names_file) { if (labels_file_contents.empty()) { @@ -68,7 +68,7 @@ absl::StatusOr> BuildLabelMapFromFiles( label_map_items[i].set_display_name(display_names[i]); } } - proto_ns::Map label_map; + proto_ns::Map label_map; for (int i = 0; i < label_map_items.size(); ++i) { label_map[i] = label_map_items[i]; } diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index 53cecc734..d5264a026 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -45,12 +45,16 @@ filegroup( "include/flatbuffers/bfbs_generator.h", "include/flatbuffers/buffer.h", "include/flatbuffers/buffer_ref.h", + "include/flatbuffers/code_generator.h", "include/flatbuffers/code_generators.h", "include/flatbuffers/default_allocator.h", "include/flatbuffers/detached_buffer.h", "include/flatbuffers/flatbuffer_builder.h", "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/flatc.h", + "include/flatbuffers/flex_flat_util.h", "include/flatbuffers/flexbuffers.h", + "include/flatbuffers/grpc.h", "include/flatbuffers/hash.h", "include/flatbuffers/idl.h", "include/flatbuffers/minireflect.h", diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl index 703cb0536..02247268b 100644 --- a/third_party/flatbuffers/workspace.bzl +++ b/third_party/flatbuffers/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-2.0.6", - sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9", + strip_prefix = "flatbuffers-23.1.21", + sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v2.0.6.tar.gz", - "https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz", + "https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz", ], build_file = "//third_party/flatbuffers:BUILD.bazel", delete = ["build_defs.bzl", "BUILD.bazel"], diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 148b5970f..a484d2f82 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,72 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"], + sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"], + sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9", + urls = 
["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", - sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"], + sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", - sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"], + sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"], + sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"], + sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", - sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"], + sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", - sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"], + sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"], + sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = 
"03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"], + sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", - sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"], + sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", - sha256 = "103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"], + sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"], )