From 5a6837d034f9583e2f43659c388638ac14ad0b7e Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 16 Nov 2022 22:08:52 -0800 Subject: [PATCH 001/346] Fix errors that will occur in python 3.11 --- mediapipe/tasks/python/audio/audio_classifier.py | 3 ++- mediapipe/tasks/python/audio/audio_embedder.py | 3 ++- mediapipe/tasks/python/text/text_classifier.py | 4 +++- mediapipe/tasks/python/text/text_embedder.py | 4 +++- mediapipe/tasks/python/vision/gesture_recognizer.py | 6 ++++-- mediapipe/tasks/python/vision/image_classifier.py | 3 ++- mediapipe/tasks/python/vision/image_embedder.py | 3 ++- 7 files changed, 18 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index 7955cc4dc..2dd1cc4a3 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -70,7 +70,8 @@ class AudioClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index a774d71e9..4484064ee 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -71,7 +71,8 @@ class AudioEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 92d547f20..c6095e1c3 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,6 +14,7 @@ """MediaPipe text classifier task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -48,7 +49,8 @@ class TextClassifierOptions: classifier_options: Options for the text classification task. """ base_options: _BaseOptions - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index f3e5eecbe..1a32796a3 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -14,6 +14,7 @@ """MediaPipe text embedder task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -49,7 +50,8 @@ class TextEmbedderOptions: embedder_options: Options for the text embedder task. 
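      Note on the pattern applied throughout this patch: Python 3.11 widened
      dataclasses' mutable-default check from list/dict/set to any unhashable
      default value, so a bare instance default such as `_EmbedderOptions()`
      now raises "ValueError: mutable default ... is not allowed: use
      default_factory". A minimal standalone sketch of the failure and the fix
      (the `Unhashable`/`Opts` classes are hypothetical, not MediaPipe code):

        import dataclasses

        class Unhashable:
          __hash__ = None  # like the options classes here, instances are unhashable

        @dataclasses.dataclass
        class Opts:
          # opt: Unhashable = Unhashable()  # ValueError on Python 3.11
          opt: Unhashable = dataclasses.field(default_factory=Unhashable)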
""" base_options: _BaseOptions - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 9b6fd8cab..8addebe4c 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -181,9 +181,11 @@ class GestureRecognizerOptions: min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 canned_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) custom_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[ [GestureRecognizerResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 763160e1e..d3c2965ba 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -70,7 +70,8 @@ class ImageClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index f299fa590..06624d16e 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -69,7 +69,8 @@ class ImageEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None From 1fb0902aa06d45ebc73f5337d9f65f06c418c24b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 14:01:14 -0800 Subject: [PATCH 002/346] Update gesture_recognizer test PiperOrigin-RevId: 489301508 --- .../vision/gesture_recognizer/gesture_recognizer_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 8a6e474d7..39272cbbc 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import random import tempfile from unittest import mock as unittest_mock import zipfile @@ -41,6 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() + random.seed(1234) all_data = self._load_data() # Splits data, 90% data for training, 10% for validation 
self._train_data, self._validation_data = all_data.split(0.9) @@ -93,11 +95,11 @@ class GestureRecognizerTest(tf.test.TestCase): tflite_file=gesture_classifier_tflite_file, size=[1, model.embedding_size]) - def _test_accuracy(self, model, threshold=0.25): + def _test_accuracy(self, model, threshold=0.0): # Test on _train_data because of our limited dataset size _, accuracy = model.evaluate(self._train_data) tf.compat.v1.logging.info(f'train accuracy: {accuracy}') - self.assertGreaterEqual(accuracy, threshold) + self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( gesture_recognizer.hyperparameters, From a7bd725e65e34ea416b15ceeffed972a2b205071 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 16:06:04 -0800 Subject: [PATCH 003/346] Internal change PiperOrigin-RevId: 489331826 --- mediapipe/gpu/gl_context.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 91d2837c5..53e3ff8b7 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -290,8 +290,15 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // some Emscripten cases), there might be some existing tripped error. ForceClearExistingGlErrors(); - absl::string_view version_string( - reinterpret_cast(glGetString(GL_VERSION))); + absl::string_view version_string; + const GLubyte* version_string_ptr = glGetString(GL_VERSION); + if (version_string_ptr != nullptr) { + version_string = reinterpret_cast(version_string_ptr); + } else { + // This may happen when using SwiftShader, but the numeric versions are + // available and will be used instead. + LOG(WARNING) << "failed to get GL_VERSION string"; + } // We will decide later whether we want to use the version numbers we query // for, or instead derive that information from the context creation result, @@ -333,7 +340,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { } LOG(INFO) << "GL version: " << gl_major_version_ << "." 
<< gl_minor_version_ - << " (" << glGetString(GL_VERSION) << ")"; + << " (" << version_string << ")"; { auto status = GetGlExtensions(); if (!status.ok()) { From ab3a5f0fbf1883c4d1dfe1df2db80a7045a390c4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 16:28:08 -0800 Subject: [PATCH 004/346] Make MuxCalculator with DefaultInputStreamHandler to handle graph closure gracefully PiperOrigin-RevId: 489336722 --- mediapipe/calculators/core/mux_calculator.cc | 4 ++++ .../calculators/core/mux_calculator_test.cc | 16 ++++++---------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index a0ce2ae34..88b04a32b 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -41,6 +41,10 @@ class MuxCalculator : public Node { StreamHandler("MuxInputStreamHandler")); absl::Status Process(CalculatorContext* cc) final { + if (kSelect(cc).IsStream() && kSelect(cc).IsEmpty()) { + return absl::OkStatus(); + } + int select = *kSelect(cc); RET_CHECK(0 <= select && select < kIn(cc).Count()); if (!kIn(cc)[select].IsEmpty()) { diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index a3ac8a27a..6b9434be9 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -439,7 +439,7 @@ TEST(MuxCalculatorTest, HandlesCloseGracefully) { EXPECT_TRUE(output_packets.empty()); } -TEST(MuxCalculatorTest, CrashesOnCloseWithDeafultInputStreamHandler) { +TEST(MuxCalculatorTest, HandlesCloseGracefullyWithDeafultInputStreamHandler) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( R"pb( @@ -480,15 +480,11 @@ TEST(MuxCalculatorTest, CrashesOnCloseWithDeafultInputStreamHandler) { MP_ASSERT_OK(graph.AddPacketToInputStream( "value_0", MakePacket(0).At(Timestamp(1000)))); MP_ASSERT_OK(graph.WaitUntilIdle()); - // Currently MuxCalculator crashes with a correct packet set from - // DefaultInputStreamHandler. The SELECT packet is missing at Timestamp 1000, - // and an empty packet is the correct representation of that. 
- EXPECT_DEATH( - { - (void)graph.CloseAllInputStreams(); - (void)graph.WaitUntilDone(); - }, - "Check failed: payload_"); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_EQ(output_packets.size(), 1); + EXPECT_TRUE(output_packets[0].IsEmpty()); } } // namespace From 6f3cb340e153af68c31462a337ee0bf1c113f7cd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 17:14:56 -0800 Subject: [PATCH 005/346] Internal change PiperOrigin-RevId: 489345940 --- .../tasks/web/audio/audio_classifier/BUILD | 2 +- .../audio_classifier/audio_classifier.ts | 2 +- mediapipe/tasks/web/core/BUILD | 4 ++-- mediapipe/tasks/web/core/task_runner.ts | 6 +++--- .../tasks/web/text/text_classifier/BUILD | 2 +- .../text/text_classifier/text_classifier.ts | 2 +- mediapipe/tasks/web/text/text_embedder/BUILD | 2 +- .../web/text/text_embedder/text_embedder.ts | 2 +- mediapipe/tasks/web/vision/core/BUILD | 2 +- .../web/vision/core/vision_task_runner.ts | 2 +- .../tasks/web/vision/gesture_recognizer/BUILD | 2 +- .../gesture_recognizer/gesture_recognizer.ts | 2 +- .../tasks/web/vision/hand_landmarker/BUILD | 2 +- .../vision/hand_landmarker/hand_landmarker.ts | 2 +- .../tasks/web/vision/image_classifier/BUILD | 2 +- .../image_classifier/image_classifier.ts | 2 +- .../tasks/web/vision/image_embedder/BUILD | 2 +- .../vision/image_embedder/image_embedder.ts | 2 +- .../tasks/web/vision/object_detector/BUILD | 2 +- .../vision/object_detector/object_detector.ts | 2 +- mediapipe/web/graph_runner/BUILD | 20 ++++++------------- ...{wasm_mediapipe_lib.ts => graph_runner.ts} | 14 ++++++------- ...image_lib.ts => graph_runner_image_lib.ts} | 10 +++++----- .../register_model_resources_graph_service.ts | 10 +++++----- 24 files changed, 46 insertions(+), 54 deletions(-) rename mediapipe/web/graph_runner/{wasm_mediapipe_lib.ts => graph_runner.ts} (99%) rename mediapipe/web/graph_runner/{wasm_mediapipe_image_lib.ts => graph_runner_image_lib.ts} (83%) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 412af3bea..9e1fcbc51 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 76b926723..5533b0eaa 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; diff --git 
a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index e9ef85d46..6eca8bb4a 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,9 +18,9 @@ mediapipe_ts_library( "task_runner.ts", ], deps = [ + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c948930fc..67aa4e4df 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -15,12 +15,12 @@ */ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; -import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; -import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; +import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; +import {GraphRunner, WasmModule} from '../../../web/graph_runner/graph_runner'; // tslint:disable-next-line:enforce-name-casing const WasmMediaPipeImageLib = - SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); + SupportModelResourcesGraphService(SupportImage(GraphRunner)); /** Base class for all MediaPipe Tasks. */ export abstract class TaskRunner extends WasmMediaPipeImageLib { diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 8c3b8e226..71ef02c92 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -26,7 +26,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index d4f413efa..04789f5e1 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 17b5eac06..c555f8d33 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + 
"//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 7c631683d..57b91d575 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -23,7 +23,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e3a5edf33..1d8944f14 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -21,6 +21,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 372ce9ba7..79ff45156 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -17,7 +17,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index f2b668239..ddfd1a327 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -32,7 +32,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8e745534e..dd050d0f1 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -31,7 +31,7 @@ import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import 
{createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 36f1d7eb7..1849687c5 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -27,7 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 0aba5c82c..32b1eed4b 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -27,7 +27,7 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index e7e830332..ebe64ecf4 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 0011e9c55..b59cb6fb1 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; diff --git 
a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index ce1c25700..feb3ae054 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index d17bc72fa..c60665052 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -23,7 +23,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 0975a9fd4..b6bef6bfa 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -22,7 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e6cbd8627..44046cd1e 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD index dab6be50f..5c12947af 100644 --- a/mediapipe/web/graph_runner/BUILD +++ b/mediapipe/web/graph_runner/BUILD @@ -3,32 +3,24 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = [ - ":internal", "//mediapipe/tasks:internal", 
]) -package_group( - name = "internal", - packages = [ - "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...", - ], -) - mediapipe_ts_library( - name = "wasm_mediapipe_lib_ts", + name = "graph_runner_ts", srcs = [ - ":wasm_mediapipe_lib.ts", + ":graph_runner.ts", ], allow_unoptimized_namespaces = True, ) mediapipe_ts_library( - name = "wasm_mediapipe_image_lib_ts", + name = "graph_runner_image_lib_ts", srcs = [ - ":wasm_mediapipe_image_lib.ts", + ":graph_runner_image_lib.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) mediapipe_ts_library( @@ -37,5 +29,5 @@ mediapipe_ts_library( ":register_model_resources_graph_service.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/graph_runner.ts similarity index 99% rename from mediapipe/web/graph_runner/wasm_mediapipe_lib.ts rename to mediapipe/web/graph_runner/graph_runner.ts index 5f8040a33..7de5aa33b 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -129,7 +129,7 @@ declare global { declare function importScripts(...urls: Array): void; /** - * Valid types of image sources which we can run our WasmMediaPipeLib over. + * Valid types of image sources which we can run our GraphRunner over. */ export type ImageSource = HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap; @@ -138,7 +138,7 @@ export type ImageSource = /** A listener that will be invoked with an absl::StatusCode and message. */ export type ErrorListener = (code: number, message: string) => void; -// Internal type of constructors used for initializing WasmMediaPipeLib and +// Internal type of constructors used for initializing GraphRunner and // subclasses. type WasmMediaPipeConstructor = (new ( @@ -151,7 +151,7 @@ type WasmMediaPipeConstructor = * into canvas, or else return the output WebGLTexture. Takes a WebAssembly * Module (must be instantiated to self.Module). */ -export class WasmMediaPipeLib { +export class GraphRunner { // TODO: These should be protected/private, but are left exposed for // now so that we can use proper TS mixins with this class as a base. This // should be somewhat fixed when we create our .d.ts files. @@ -989,7 +989,7 @@ async function runScript(scriptUrl: string) { /** * Global function to initialize Wasm blob and load runtime assets for a * specialized MediaPipe library. This allows us to create a requested - * subclass inheriting from WasmMediaPipeLib. + * subclass inheriting from GraphRunner. * @param constructorFcn The name of the class to instantiate via "new". * @param wasmLoaderScript Url for the wasm-runner script; produced by the build * process. @@ -1043,12 +1043,12 @@ export async function createMediaPipeLib( * @return promise A promise which will resolve when initialization has * completed successfully. 
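 * Usage sketch (illustrative only; the script URLs below are placeholders):
 *
 *   const runner = await createGraphRunner(
 *       '/wasm/graph_wasm_internal.js', '/wasm/graph_assets.js',
 *       document.querySelector('canvas'));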
*/ -export async function createWasmMediaPipeLib( +export async function createGraphRunner( wasmLoaderScript?: string, assetLoaderScript?: string, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, - fileLocator?: FileLocator): Promise { + fileLocator?: FileLocator): Promise { return createMediaPipeLib( - WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas, + GraphRunner, wasmLoaderScript, assetLoaderScript, glCanvas, fileLocator); } diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts similarity index 83% rename from mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts rename to mediapipe/web/graph_runner/graph_runner_image_lib.ts index 3b45e8230..e886999cb 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -1,12 +1,12 @@ -import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {ImageSource, GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler @@ -19,10 +19,10 @@ export declare interface WasmImageModule { } /** - * An implementation of WasmMediaPipeLib that supports binding GPU image data as + * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. Example usage: - * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);` + * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index e85d63b06..bc9c93e8a 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -1,12 +1,12 @@ -import {WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler @@ -17,11 +17,11 @@ export declare interface WasmModuleRegisterModelResources { } /** - * An implementation of WasmMediaPipeLib that supports registering model + * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. 
Sample usage: * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * WasmMediaPipeLib);` + * GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService( From efcdedbd59a135d757a49b0ff27b656e793386ad Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Thu, 17 Nov 2022 18:14:58 -0800 Subject: [PATCH 006/346] Remove redundant _ios targets PiperOrigin-RevId: 489355333 --- mediapipe/gpu/BUILD | 14 -------------- mediapipe/objc/BUILD | 4 ++-- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 4fb59f1b5..27d91f21a 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -810,20 +810,6 @@ cc_library( }), ) -# TODO: remove -objc_library( - name = "gl_calculator_helper_ios", - copts = [ - "-Wno-shorten-64-to-32", - ], - visibility = ["//visibility:public"], - deps = [ - ":gl_calculator_helper", - "//mediapipe/objc:mediapipe_framework_ios", - "//mediapipe/objc:util", - ], -) - objc_library( name = "MPPMetalHelper", srcs = ["MPPMetalHelper.mm"], diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 48c9b181a..d77692164 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -147,7 +147,7 @@ objc_library( visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", ], @@ -173,7 +173,7 @@ objc_library( deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", ], ) From ae44012c0c5a53916f9ee01b3c745868836c784b Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Fri, 18 Nov 2022 08:39:37 -0800 Subject: [PATCH 007/346] Allowing BypassCalculator to accept InputSidePackets. 
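The loop added below declares every input side packet as Any in the calculator
contract, so graphs can now wire arbitrary side packets through a
BypassCalculator node.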
PiperOrigin-RevId: 489483992 --- mediapipe/calculators/core/bypass_calculator.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc index efc0612ec..4e007329b 100644 --- a/mediapipe/calculators/core/bypass_calculator.cc +++ b/mediapipe/calculators/core/bypass_calculator.cc @@ -111,6 +111,10 @@ class BypassCalculator : public Node { cc->Outputs().Get(id).SetAny(); } } + for (auto id = cc->InputSidePackets().BeginId(); + id != cc->InputSidePackets().EndId(); ++id) { + cc->InputSidePackets().Get(id).SetAny(); + } return absl::OkStatus(); } From e046982a3c6706625c997df50e51e19157624ac7 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 18 Nov 2022 08:44:02 -0800 Subject: [PATCH 008/346] Internal change PiperOrigin-RevId: 489484898 --- .../tensor/audio_to_tensor_calculator.cc | 49 ++++++++++++++++--- .../tensor/audio_to_tensor_calculator.proto | 13 +++++ 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index d0513518a..9cb23a393 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -43,6 +43,7 @@ namespace api2 { namespace { using Options = ::mediapipe::AudioToTensorCalculatorOptions; +using DftTensorFormat = Options::DftTensorFormat; using FlushMode = Options::FlushMode; std::vector HannWindow(int window_size, bool sqrt_hann) { @@ -188,6 +189,8 @@ class AudioToTensorCalculator : public Node { int padding_samples_before_; int padding_samples_after_; FlushMode flush_mode_; + DftTensorFormat dft_tensor_format_; + Timestamp initial_timestamp_ = Timestamp::Unstarted(); int64 cumulative_input_samples_ = 0; Timestamp next_output_timestamp_ = Timestamp::Unstarted(); @@ -273,6 +276,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { } padding_samples_before_ = options.padding_samples_before(); padding_samples_after_ = options.padding_samples_after(); + dft_tensor_format_ = options.dft_tensor_format(); flush_mode_ = options.flush_mode(); RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ @@ -492,14 +496,43 @@ absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block, kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]), timestamp); } - Matrix fft_output_matrix = - Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); - fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); - // The last two elements are the DFT Nyquist values. - fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part - fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part - ASSIGN_OR_RETURN(output_tensor, - ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2})); + switch (dft_tensor_format_) { + case Options::WITH_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); + // The last two elements are Nyquist component. 
+ fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part + ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(fft_output_matrix, + {2, fft_size_ / 2})); + break; + } + case Options::WITH_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data(), 1, fft_size_); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_ + 2); + fft_output_matrix(1) = 0.0f; // DC imagery part. + // The last two elements are Nyquist component. + fft_output_matrix(fft_size_) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ + 1) = 0.0f; // Nyquist imagery part + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ + 2) / 2})); + break; + } + case Options::WITHOUT_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ - 2) / 2})); + break; + } + default: + return absl::InvalidArgumentError("Unsupported dft tensor format."); + } + } else { ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(block, {num_channels_, num_samples_})); diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index cff6b2878..aa3c1229c 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -68,4 +68,17 @@ message AudioToTensorCalculatorOptions { } optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX]; + + enum DftTensorFormat { + DFT_TENSOR_FORMAT_UNKNOWN = 0; + // The output dft tensor without dc and nyquist components. + WITHOUT_DC_AND_NYQUIST = 1; + // The output dft tensor contains the nyquist component as the last + // two values. + WITH_NYQUIST = 2; + // The output dft tensor contains the dc component as the first two values + // and the nyquist component as the last two values. 
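  // (Illustrative packing for WITH_DC_AND_NYQUIST, assuming fft_size_ = 8:
  //  the flat output is [DC.re, 0, b1.re, b1.im, b2.re, b2.im, b3.re, b3.im,
  //  Nyq.re, 0], i.e. fft_size_ + 2 floats reshaped to {2, (fft_size_ + 2) / 2};
  //  the zeros are the always-zero imaginary parts of the DC and Nyquist bins
  //  of a real-input DFT.)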
+ WITH_DC_AND_NYQUIST = 3; + } + optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST]; } From 2f361e2f4791fa774db5cb20dbc888f89c234447 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 08:51:30 -0800 Subject: [PATCH 009/346] Internal change PiperOrigin-RevId: 489486417 --- mediapipe/util/tracking/BUILD | 3 +-- mediapipe/util/tracking/motion_analysis.cc | 2 +- .../util/tracking/region_flow_computation.cc | 16 ++++++---------- .../tracking/region_flow_computation_test.cc | 2 +- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 319e99d5b..3f1ebb353 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -458,7 +458,6 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", ], ) @@ -739,7 +738,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", diff --git a/mediapipe/util/tracking/motion_analysis.cc b/mediapipe/util/tracking/motion_analysis.cc index 0b7678889..5b6a970cf 100644 --- a/mediapipe/util/tracking/motion_analysis.cc +++ b/mediapipe/util/tracking/motion_analysis.cc @@ -791,7 +791,7 @@ void MotionAnalysis::VisualizeBlurAnalysisRegions(cv::Mat* input_view) { region_flow_computation_->ComputeBlurMask(*input_view, &corner_values, &mask); cv::Mat mask_3c; - cv::cvtColor(mask, mask_3c, CV_GRAY2RGB); + cv::cvtColor(mask, mask_3c, cv::COLOR_GRAY2RGB); cv::addWeighted(*input_view, 0.5, mask_3c, 0.5, -128, *input_view); } diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc index cfd5c23c2..708c868b5 100644 --- a/mediapipe/util/tracking/region_flow_computation.cc +++ b/mediapipe/util/tracking/region_flow_computation.cc @@ -30,6 +30,7 @@ #include "absl/container/node_hash_set.h" #include "absl/memory/memory.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_features2d_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_video_inc.h" @@ -935,12 +936,13 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, // Area based method best for downsampling. // For color images to temporary buffer. cv::Mat& resized = source.channels() == 1 ? dest_frame : *curr_color_image_; - cv::resize(source, resized, resized.size(), 0, 0, CV_INTER_AREA); + cv::resize(source, resized, resized.size(), 0, 0, cv::INTER_AREA); source_ptr = &resized; // Resize feature extraction mask if needed. 
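  // (The legacy C-API constants replaced throughout this patch map one-to-one
  //  onto the C++ enums: CV_INTER_AREA -> cv::INTER_AREA, CV_INTER_NN ->
  //  cv::INTER_NEAREST, CV_GRAY2RGB -> cv::COLOR_GRAY2RGB. The CV_* names are
  //  absent from OpenCV 4's default headers, which is presumably what
  //  motivates the change.)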
if (!source_mask.empty()) { dest_mask.create(resized.rows, resized.cols, CV_8UC1); - cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, CV_INTER_NN); + cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, + cv::INTER_NEAREST); } } else if (!source_mask.empty()) { source_mask.copyTo(dest_mask); @@ -954,7 +956,7 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, const int dimension = visual_options.tiny_image_dimension(); data->tiny_image.create(dimension, dimension, type); cv::resize(*source_ptr, data->tiny_image, data->tiny_image.size(), 0, 0, - CV_INTER_AREA); + cv::INTER_AREA); } if (source_ptr->channels() == 1 && @@ -2286,7 +2288,7 @@ void RegionFlowComputation::ExtractFeatures( // Initialize mask from frame's feature extraction mask, by downsampling and // negating the latter mask. if (!data->mask.empty()) { - cv::resize(data->mask, mask, mask.size(), 0, 0, CV_INTER_NN); + cv::resize(data->mask, mask, mask.size(), 0, 0, cv::INTER_NEAREST); for (int y = 0; y < mask.rows; ++y) { uint8* mask_ptr = mask.ptr(y); for (int x = 0; x < mask.cols; ++x) { @@ -2590,12 +2592,6 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, cv::_InputArray input_frame2(data2.pyramid); #endif - // Using old c-interface for OpenCV's 2.2 tracker. - CvTermCriteria criteria; - criteria.type = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER; - criteria.max_iter = options_.tracking_options().tracking_iterations(); - criteria.epsilon = 0.02f; - feature_track_error_.resize(num_features); feature_status_.resize(num_features); if (use_cv_tracking_) { diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 0ac6dc2a5..435a8e200 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -28,7 +28,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" -#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" From 03d388fecffe3734d8f6878f6f0def404065076b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 09:49:23 -0800 Subject: [PATCH 010/346] Add hand landmark named index constants PiperOrigin-RevId: 489498248 --- .../tasks/cc/components/containers/BUILD | 5 ++ .../tasks/cc/components/containers/landmark.h | 48 +++++++++++++ .../tasks/components/containers/BUILD | 12 ++++ .../components/containers/HandLandmark.java | 72 +++++++++++++++++++ .../python/components/containers/landmark.py | 26 +++++++ .../web/components/containers/landmark.d.ts | 25 +++++++ 6 files changed, 188 insertions(+) create mode 100644 mediapipe/tasks/cc/components/containers/landmark.h create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index bd66a0f28..2f5f8be5b 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,3 +49,8 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "landmark", + hdrs = ["landmark.h"], +) diff --git 
a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h new file mode 100644 index 000000000..6fdd294ae --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ + +namespace mediapipe::tasks::components::containers { + +// The 21 hand landmarks. +enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +}; + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..869157295 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -74,6 +74,18 @@ android_library( ], ) +android_library( + name = "handlandmark", + srcs = ["HandLandmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "landmark", srcs = ["Landmark.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java new file mode 100644 index 000000000..da7c4e0ca --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java @@ -0,0 +1,72 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
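// Usage sketch (illustrative; `handLandmarks` stands for a hypothetical
// List<Landmark> returned by a hand landmarker task):
//   Landmark wrist = handLandmarks.get(HandLandmark.WRIST);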
+ +package com.google.mediapipe.tasks.components.containers; + +import androidx.annotation.IntDef; + +/** The 21 hand landmarks. */ +public final class HandLandmark { + public static final int NUM_LANDMARKS = 21; + + public static final int WRIST = 0; + public static final int THUMB_CMC = 1; + public static final int THUMB_MCP = 2; + public static final int THUMB_IP = 3; + public static final int THUMB_TIP = 4; + public static final int INDEX_FINGER_MCP = 5; + public static final int INDEX_FINGER_PIP = 6; + public static final int INDEX_FINGER_DIP = 7; + public static final int INDEX_FINGER_TIP = 8; + public static final int MIDDLE_FINGER_MCP = 9; + public static final int MIDDLE_FINGER_PIP = 10; + public static final int MIDDLE_FINGER_DIP = 11; + public static final int MIDDLE_FINGER_TIP = 12; + public static final int RING_FINGER_MCP = 13; + public static final int RING_FINGER_PIP = 14; + public static final int RING_FINGER_DIP = 15; + public static final int RING_FINGER_TIP = 16; + public static final int PINKY_MCP = 17; + public static final int PINKY_PIP = 18; + public static final int PINKY_DIP = 19; + public static final int PINKY_TIP = 20; + + /** Represents a hand landmark type. */ + @IntDef({ + WRIST, + THUMB_CMC, + THUMB_MCP, + THUMB_IP, + THUMB_TIP, + INDEX_FINGER_MCP, + INDEX_FINGER_PIP, + INDEX_FINGER_DIP, + INDEX_FINGER_TIP, + MIDDLE_FINGER_MCP, + MIDDLE_FINGER_PIP, + MIDDLE_FINGER_DIP, + MIDDLE_FINGER_TIP, + RING_FINGER_MCP, + RING_FINGER_PIP, + RING_FINGER_DIP, + RING_FINGER_TIP, + PINKY_MCP, + PINKY_PIP, + PINKY_DIP, + PINKY_TIP, + }) + public @interface HandLandmarkType {} + + private HandLandmark() {} +} diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index dee2a16ad..81b2943dc 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -14,6 +14,7 @@ """Landmark data class.""" import dataclasses +import enum from typing import Optional from mediapipe.framework.formats import landmark_pb2 @@ -120,3 +121,28 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) + + +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index c887303d0..352717a2f 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -33,3 +33,28 @@ export declare interface Landmark { /** Whether this landmark is normalized with respect to the image size. */ normalized: boolean; } + +/** The 21 hand landmarks. 
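 * These values follow the model's output order, so they can serve directly as
 * indices into a returned landmark array, e.g.
 * `handLandmarks[HandLandmark.WRIST]` (illustrative).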
*/ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} From ac212c15070854b407812148739f6e1b72089a75 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Fri, 18 Nov 2022 10:06:47 -0800 Subject: [PATCH 011/346] Internal change PiperOrigin-RevId: 489502255 --- mediapipe/calculators/audio/BUILD | 1 - mediapipe/calculators/core/BUILD | 6 ++---- mediapipe/calculators/image/BUILD | 10 +++++----- mediapipe/calculators/tensor/BUILD | 6 +++--- mediapipe/calculators/tensorflow/BUILD | 14 ++++++++------ mediapipe/calculators/tflite/BUILD | 6 +++--- mediapipe/calculators/util/BUILD | 9 ++++----- mediapipe/calculators/video/BUILD | 4 ++-- mediapipe/framework/BUILD | 4 ---- mediapipe/framework/formats/BUILD | 8 +++++--- mediapipe/framework/formats/motion/BUILD | 4 ++-- mediapipe/framework/profiler/BUILD | 4 ++++ mediapipe/framework/stream_handler/BUILD | 4 ++-- mediapipe/framework/tool/BUILD | 7 ++----- mediapipe/gpu/BUILD | 1 - 15 files changed, 42 insertions(+), 46 deletions(-) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index ba461e4a7..555f7543f 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -197,7 +197,6 @@ cc_library( ":spectrogram_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ecd878115..39837fadb 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -341,7 +341,6 @@ cc_test( srcs = ["concatenate_proto_list_calculator_test.cc"], deps = [ ":concatenate_proto_list_calculator", - ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -403,7 +402,6 @@ cc_test( srcs = ["clip_vector_size_calculator_test.cc"], deps = [ ":clip_vector_size_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -956,10 +954,10 @@ cc_library( deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 89e2d371c..c78bc5cf7 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -159,8 +159,8 @@ cc_library( visibility = 
["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -186,8 +186,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -290,10 +290,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", @@ -361,12 +361,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", + "//mediapipe/util:color_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - "//mediapipe/util:color_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", ] + select({ @@ -630,8 +630,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 3f1278397..4c06df0ff 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -433,6 +433,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":inference_calculator_cc_proto", ":inference_calculator_options_lib", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -794,12 +795,12 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", ] + selects.with_or({ ":compute_shader_unavailable": [], @@ -1279,7 +1280,6 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1378,9 +1378,9 
@@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", "@org_tensorflow//tensorflow/lite:framework", - "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/framework/port:statusor", ] + selects.with_or({ "//mediapipe/gpu:disable_gpu": [], diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index d0dfc12ab..45f64f4f7 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -346,8 +346,8 @@ cc_library( srcs = ["matrix_to_tensor_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -414,7 +414,7 @@ cc_library( "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", # build_cleaner: keep + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:opencv_imgcodecs", @@ -451,8 +451,8 @@ cc_library( srcs = ["tensorflow_inference_calculator.cc"], visibility = ["//visibility:public"], deps = [ - ":tensorflow_session", ":tensorflow_inference_calculator_cc_proto", + ":tensorflow_session", "@com_google_absl//absl/log:check", "//mediapipe/framework:timestamp", "@com_google_absl//absl/base:core_headers", @@ -515,6 +515,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -546,6 +547,7 @@ cc_library( "//mediapipe/framework/deps:clock", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -666,8 +668,8 @@ cc_library( srcs = ["tensor_to_matrix_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":tensor_to_matrix_calculator_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -704,10 +706,10 @@ cc_library( srcs = ["tensor_to_vector_float_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - ":tensor_to_vector_float_calculator_options_cc_proto", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", @@ -1083,7 +1085,6 @@ cc_test( linkstatic = 1, deps = [ ":tensor_to_image_frame_calculator", - ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", 
"//mediapipe/framework/formats:image_frame", @@ -1236,6 +1237,7 @@ cc_test( data = [":test_frozen_graph"], linkstatic = 1, deps = [ + ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", ":tensorflow_inference_calculator", ":tensorflow_session_from_frozen_graph_generator", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2007a4fe1..8edaeee02 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -289,8 +289,8 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_converter_calculator_cc_proto", + "//mediapipe/util/tflite:config", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -410,15 +410,15 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/util/tflite:config", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/deps:file_path", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:location", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", ] + selects.with_or({ diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 3a9ddc36f..24e976a73 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -23,8 +23,8 @@ cc_library( srcs = ["alignment_points_to_rects_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -266,8 +266,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", @@ -755,7 +755,6 @@ cc_library( deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:ret_check", @@ -1313,8 +1312,8 @@ cc_library( srcs = ["to_image_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image_frame", @@ -1336,8 +1335,8 @@ cc_library( srcs = ["from_image_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", 
"//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 53d968151..2db3ed252 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -342,12 +342,12 @@ cc_library( "//mediapipe/framework/port:opencv_features2d", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/util/tracking:box_tracker_cc_proto", + "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/util/tracking", "//mediapipe/util/tracking:box_detector", "//mediapipe/util/tracking:box_tracker", - "//mediapipe/util/tracking:box_tracker_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util/tracking:tracking_visualization_utilities", ] + select({ "//mediapipe:android": [ diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 8ccdac3b9..e3429f1e9 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1039,7 +1039,6 @@ cc_library( ":graph_service_manager", ":port", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1660,9 +1659,6 @@ cc_test( "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:default_side_packet_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:template_parser", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index c3241d911..e13bb2704 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -133,9 +133,9 @@ cc_library( "//visibility:public", ], deps = [ + ":affine_transform_data_cc_proto", "//mediapipe/framework:port", "//mediapipe/framework:type_map", - "//mediapipe/framework/formats:affine_transform_data_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:point", @@ -209,8 +209,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats/annotation:locus_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -241,6 +241,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":location", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:opencv_imgproc", ], alwayslink = 1, @@ -251,6 +252,7 @@ cc_test( srcs = ["location_opencv_test.cc"], deps = [ ":location_opencv", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:rectangle", ], @@ -346,8 +348,8 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - 
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 28e0bfc6a..9819d262c 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,10 +16,10 @@ # Description: # Working with dense optical flow in mediapipe. -licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 237aa825f..b53a1ac39 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -334,6 +334,10 @@ cc_library( "graph_profiler_stub.h", ], visibility = ["//mediapipe/framework:__pkg__"], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + ], ) cc_test( diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 8771a8773..866a5120e 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package( @@ -20,8 +22,6 @@ package( features = ["-layering_check"], ) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index e54fb2177..52d04b4b1 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,12 +299,12 @@ mediapipe_cc_test( data = [":node_chain_subgraph.proto"], requires_full_emulation = False, deps = [ + ":node_chain_subgraph_cc_proto", ":options_field_util", ":options_registry", ":options_syntax_util", ":options_util", "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:basic_types_registration", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -312,6 +312,7 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", @@ -486,7 +487,6 @@ cc_library( deps = [ ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", @@ -738,9 +738,7 @@ cc_test( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:graph_service_manager", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", 
"//mediapipe/framework:packet_type", "//mediapipe/framework:status_handler", @@ -923,7 +921,6 @@ cc_test( "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:subgraph", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 27d91f21a..10a8d7fff 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -783,7 +783,6 @@ cc_library( ":image_frame_view", ":shader_util", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_cc_proto", "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", From e2052a6a517fe1d8ce487f46a9856a225644d3f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 11:11:22 -0800 Subject: [PATCH 012/346] Rename embedding postprocessor "configure" method for consistency with classification postprocessor. PiperOrigin-RevId: 489518257 --- .../audio/audio_embedder/audio_embedder_graph.cc | 10 ++++++---- .../processors/embedding_postprocessing_graph.cc | 6 +++--- .../processors/embedding_postprocessing_graph.h | 2 +- .../embedding_postprocessing_graph_test.cc | 14 +++++++------- .../cc/text/text_embedder/text_embedder_graph.cc | 10 ++++++---- .../vision/image_embedder/image_embedder_graph.cc | 10 ++++++---- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc index 7667feaa3..f093b4d25 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc @@ -158,10 +158,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio embedding on // audio files. Disables timestamp aggregation by not connecting the diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 880aec5d7..ad4881e12 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -150,7 +150,7 @@ absl::StatusOr> GetHeadNames( } // namespace -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options) { @@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing( // timestamp aggregation is required. // // The recommended way of using this graph is through the GraphBuilder API using -// the 'ConfigureEmbeddingPostprocessing()' function. 
See header file for more -// details. +// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for +// more details. class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 58606ed80..889992463 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -58,7 +58,7 @@ namespace processors { // The embedding result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options); diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 84d84d648..163e46ee8 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { options_in.set_quantize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors." "EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( *model_resources, options, &postprocessing .GetOptions())); diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index 79eedb6b5..c54636ee2 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { // inference results. 
auto& postprocessing = graph.AddNode(
 "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
- MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
- model_resources, task_options.embedder_options(),
- &postprocessing.GetOptions()));
+ MP_RETURN_IF_ERROR(
+ components::processors::ConfigureEmbeddingPostprocessingGraph(
+ model_resources, task_options.embedder_options(),
+ &postprocessing
+ .GetOptions()));
 inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
 // Outputs the embedding result.
diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
index 11e25144c..bf0dcf3c7 100644
--- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
@@ -151,10 +151,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
 // inference results.
 auto& postprocessing = graph.AddNode(
 "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
- MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
- model_resources, task_options.embedder_options(),
- &postprocessing.GetOptions()));
+ MP_RETURN_IF_ERROR(
+ components::processors::ConfigureEmbeddingPostprocessingGraph(
+ model_resources, task_options.embedder_options(),
+ &postprocessing
+ .GetOptions()));
 inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
 // Outputs the embedding results.

From 71ae496a2001d1206b792bedd45d4027d7f043c7 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 18 Nov 2022 12:10:47 -0800
Subject: [PATCH 013/346] Add AudioEmbedder documentation

PiperOrigin-RevId: 489532283
---
 .../audio_embedder/audio_embedder_graph.cc | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
index f093b4d25..187f11f7f 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
@@ -100,6 +100,46 @@ void ConfigureAudioToTensorCalculator(
 }
 }  // namespace
+// An "AudioEmbedderGraph" performs embedding extraction.
+// - Accepts a CPU audio buffer and outputs embedding results on CPU.
+//
+// Inputs:
+// AUDIO - Matrix
+// Audio buffer to perform embedding extraction on.
+// SAMPLE_RATE - double @Optional
+// The sample rate of the corresponding audio data in the "AUDIO" stream.
+// If sample rate is not provided, the "AUDIO" stream must carry a time
+// series stream header with sample rate info.
+//
+// Outputs:
+// EMBEDDINGS - EmbeddingResult @Optional
+// The embedding results aggregated by head. Only produces results if
+// the 'use_stream_mode' option is true.
+// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
+// The embedding result aggregated by timestamp, then by head. Only
+// produces results if the 'use_stream_mode' option is false.
+//
+// Example:
+// node {
+// calculator: "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph"
+// input_stream: "AUDIO:audio_in"
+// input_stream: "SAMPLE_RATE:sample_rate_in"
+// output_stream: "EMBEDDINGS:embeddings_out"
+// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
+// options {
+// [mediapipe.tasks.audio.audio_embedder.proto.AudioEmbedderGraphOptions.ext]
+// {
+// base_options {
+// model_asset {
+// file_name: "/path/to/model.tflite"
+// }
+// }
+// embedder_options {
+// l2_normalize: true
+// }
+// }
+// }
+// }
 class AudioEmbedderGraph : public core::ModelTaskGraph {
 public:
 absl::StatusOr GetConfig(

From 1b594a0310f9c1bc3ece2562455bba0f812efd3a Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Fri, 18 Nov 2022 12:42:58 -0800
Subject: [PATCH 014/346] Return an error status when any TFLite input or output tensor doesn't have the valid dimensionality information needed to allocate a GL/Metal buffer before calling ModifyGraphWithDelegate.

PiperOrigin-RevId: 489539740
---
 mediapipe/calculators/tensor/BUILD | 2 ++
 mediapipe/calculators/tensor/inference_calculator_gl.cc | 8 ++++++++
 .../calculators/tensor/inference_calculator_metal.cc | 7 +++++++
 3 files changed, 17 insertions(+)

diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index 4c06df0ff..2a573fc44 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -464,6 +464,7 @@ cc_library(
 "//mediapipe/gpu:gl_calculator_helper",
 "@com_google_absl//absl/memory",
 "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
 "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
 ],
 alwayslink = 1,
@@ -513,6 +514,7 @@ cc_library(
 "//mediapipe/objc:mediapipe_framework_ios",
 "//mediapipe/util/tflite:config",
 "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
 "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
 "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal",
 "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc
index bd8eb3eed..27b8bc23a 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc
@@ -20,6 +20,7 @@
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_format.h"
 #include "mediapipe/calculators/tensor/inference_calculator.h"
 #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/framework/calculator_context.h"
@@ -154,6 +155,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate(
 const auto& input_indices = interpreter_->inputs();
 for (int i = 0; i < input_indices.size(); ++i) {
 const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]);
+ RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+ "Input tensor at index [%d] doesn't specify dimensions.",
+ input_indices[i]);
+
 gpu_buffers_in_.emplace_back(absl::make_unique(
 Tensor::ElementType::kFloat32,
 Tensor::Shape{std::vector{
@@ -171,6 +176,9 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate(
 // Create and bind output buffers.
for (int i = 0; i < output_size_; ++i) {
 const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]);
+ RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+ "Output tensor at index [%d] doesn't specify dimensions.",
+ output_indices[i]);
 gpu_buffers_out_.emplace_back(absl::make_unique(
 Tensor::ElementType::kFloat32,
 Tensor::Shape{std::vector{
diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc
index a85071f3e..750f0456e 100644
--- a/mediapipe/calculators/tensor/inference_calculator_metal.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc
@@ -22,6 +22,7 @@
 #include
 #include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
 #include "mediapipe/calculators/tensor/inference_calculator.h"
 #import "mediapipe/gpu/MPPMetalHelper.h"
 #include "mediapipe/gpu/MPPMetalUtil.h"
@@ -245,6 +246,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
 const auto& input_indices = interpreter_->inputs();
 for (int i = 0; i < input_indices.size(); ++i) {
 const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]);
+ RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+ "Input tensor at index [%d] doesn't specify dimensions.",
+ input_indices[i]);
 // Create and bind input buffer.
 std::vector dims{tensor->dims->data,
 tensor->dims->data + tensor->dims->size};
@@ -266,6 +270,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
 output_shapes_.resize(output_indices.size());
 for (int i = 0; i < output_shapes_.size(); ++i) {
 const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]);
+ RET_CHECK(tensor->dims->size > 0) << absl::StrFormat(
+ "Output tensor at index [%d] doesn't specify dimensions.",
+ output_indices[i]);
 RET_CHECK(tensor->dims->size <= 4);
 // Create and bind output buffers.
 // Channels are always padded to multiple of 4.

From 524ac3ca61dc165f23a8d6ce29a9ff36d2fa7e98 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 18 Nov 2022 12:45:56 -0800
Subject: [PATCH 015/346] Internal change for Model Maker

PiperOrigin-RevId: 489540387
---
 mediapipe/model_maker/python/core/tasks/classifier.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py
index 200726864..f376edffa 100644
--- a/mediapipe/model_maker/python/core/tasks/classifier.py
+++ b/mediapipe/model_maker/python/core/tasks/classifier.py
@@ -91,6 +91,10 @@ class Classifier(custom_model.CustomModel):
 self._history = self._model.fit(
 x=train_dataset,
 epochs=self._hparams.epochs,
+ # `steps_per_epoch` is intentionally set to None in case the dataset
+ # is not repeated. Otherwise, the training process will stop when the
+ # dataset is exhausted even if there are epochs remaining.
+ steps_per_epoch=None,
 validation_data=validation_dataset,
 callbacks=self._callbacks)

From bbd5da7971aa0d39bbeba638de34ded860bd30b3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 18 Nov 2022 17:10:54 -0800
Subject: [PATCH 016/346] Added grayscale image support for the ImageToTensorCalculator on CPU.
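For illustration only, a minimal sketch of the single-channel input this
change enables, assuming nothing beyond the in-tree ImageFrame/Image
constructors already used by the updated tests below (the helper name is
hypothetical, not part of this change):

#include <cstring>
#include <memory>
#include <utility>

#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"

// Builds a GRAY8 mediapipe::Image. With this change, the CPU (OpenCV)
// converter maps such a 1-channel input to a 1-channel output tensor
// instead of rejecting it as an unsupported format.
mediapipe::Image MakeGrayInput(int width, int height) {
  auto frame = std::make_shared<mediapipe::ImageFrame>(
      mediapipe::ImageFormat::GRAY8, width, height);
  // Mid-gray placeholder pixels; real callers would copy decoded data here.
  std::memset(frame->MutablePixelData(), 128,
              frame->Height() * frame->WidthStep());
  return mediapipe::Image(std::move(frame));
}

As the TODO in image_to_tensor_utils.cc below notes, the GPU converters are
not covered by this change; only the CPU path accepts 1-channel input.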
PiperOrigin-RevId: 489593917
---
 .../tensor/image_to_tensor_calculator_test.cc | 79 ++++++++++++++++---
 .../image_to_tensor_converter_opencv.cc | 29 ++++---
 .../tensor/image_to_tensor_utils.cc | 7 +-
 3 files changed, 93 insertions(+), 22 deletions(-)

diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
index 07a5f9fe1..7ea60d98e 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
@@ -54,6 +54,13 @@ cv::Mat GetRgba(absl::string_view path) {
 return rgb;
 }
+cv::Mat GetGray(absl::string_view path) {
+ cv::Mat bgr = cv::imread(file::JoinPath("./", path));
+ cv::Mat gray;
+ cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
+ return gray;
+}
+
 // Image to tensor test template.
 // No processing/assertions should be done after the function is invoked.
 void RunTestWithInputImagePacket(const Packet& input_image_packet,
@@ -147,29 +154,34 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
 ASSERT_THAT(tensor_vec, testing::SizeIs(1));
 const Tensor& tensor = tensor_vec[0];
+ const int channels = tensor.shape().dims[3];
+ ASSERT_TRUE(channels == 1 || channels == 3);
 auto view = tensor.GetCpuReadView();
 cv::Mat tensor_mat;
 if (output_int_tensor) {
 if (range_min < 0) {
 EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8);
- tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3,
+ tensor_mat = cv::Mat(tensor_height, tensor_width,
+ channels == 1 ? CV_8SC1 : CV_8SC3,
 const_cast(view.buffer()));
 } else {
 EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
- tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
+ tensor_mat = cv::Mat(tensor_height, tensor_width,
+ channels == 1 ? CV_8UC1 : CV_8UC3,
 const_cast(view.buffer()));
 }
 } else {
 EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32);
- tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3,
+ tensor_mat = cv::Mat(tensor_height, tensor_width,
+ channels == 1 ? CV_32FC1 : CV_32FC3,
 const_cast(view.buffer()));
 }
 cv::Mat result_rgb;
 auto transformation =
 GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value();
- tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale,
- transformation.offset);
+ tensor_mat.convertTo(result_rgb, channels == 1 ? CV_8UC1 : CV_8UC3,
+ transformation.scale, transformation.offset);
 cv::Mat diff;
 cv::absdiff(result_rgb, expected_result, diff);
@@ -185,17 +197,27 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
 MP_ASSERT_OK(graph.WaitUntilDone());
 }
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
+ if (image_channels == 4) {
+ return ImageFormat::SRGBA;
+ } else if (image_channels == 3) {
+ return ImageFormat::SRGB;
+ } else if (image_channels == 1) {
+ return ImageFormat::GRAY8;
+ }
+ CHECK(false) << "Unsupported input image channels: " << image_channels;
+}
+
 Packet MakeImageFramePacket(cv::Mat input) {
- ImageFrame input_image(
- input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB,
- input.cols, input.rows, input.step, input.data, [](uint8*) {});
+ ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
+ input.rows, input.step, input.data, [](uint8*) {});
 return MakePacket(std::move(input_image)).At(Timestamp(0));
 }
 Packet MakeImagePacket(cv::Mat input) {
 mediapipe::Image input_image(std::make_shared(
 input.channels() == 4 ?
ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {})); + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); return MakePacket(std::move(input_image)).At(Timestamp(0)); } @@ -429,6 +451,24 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { /*border_mode=*/{}, roi); } +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation.png"), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/{}, roi); +} + TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { mediapipe::NormalizedRect roi; @@ -448,6 +488,25 @@ TEST(ImageToTensorCalculatorTest, /*border_mode=*/BorderMode::kZero, roi); } +TEST(ImageToTensorCalculatorTest, + LargeSubRectKeepAspectWithRotationBorderZeroGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/BorderMode::kZero, roi); +} + TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { mediapipe::NormalizedRect roi; roi.set_x_center(0.5f); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index f910b59f3..76e46f99d 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -48,15 +48,19 @@ class OpenCvProcessor : public ImageToTensorConverter { switch (tensor_type_) { case Tensor::ElementType::kInt8: mat_type_ = CV_8SC3; + mat_gray_type_ = CV_8SC1; break; case Tensor::ElementType::kFloat32: mat_type_ = CV_32FC3; + mat_gray_type_ = CV_32FC1; break; case Tensor::ElementType::kUInt8: mat_type_ = CV_8UC3; + mat_gray_type_ = CV_8UC1; break; default: mat_type_ = -1; + mat_gray_type_ = -1; } } @@ -64,11 +68,13 @@ class OpenCvProcessor : public ImageToTensorConverter { float range_min, float range_max, int tensor_buffer_offset, Tensor& output_tensor) override { - if (input.image_format() != mediapipe::ImageFormat::SRGB && - input.image_format() != mediapipe::ImageFormat::SRGBA) { - return InvalidArgumentError( - absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", - static_cast(input.image_format()))); + const bool is_supported_format = + input.image_format() == mediapipe::ImageFormat::SRGB || + input.image_format() == mediapipe::ImageFormat::SRGBA || + input.image_format() == mediapipe::ImageFormat::GRAY8; + if (!is_supported_format) { + return InvalidArgumentError(absl::StrCat( + "Unsupported format: ", 
static_cast(input.image_format())));
 }
 // TODO: Remove the check once tensor_buffer_offset > 0 is
 // supported.
@@ -82,17 +88,18 @@ class OpenCvProcessor : public ImageToTensorConverter {
 const int output_channels = output_shape.dims[3];
 auto buffer_view = output_tensor.GetCpuWriteView();
 cv::Mat dst;
+ const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
 switch (tensor_type_) {
 case Tensor::ElementType::kInt8:
- dst = cv::Mat(output_height, output_width, mat_type_,
+ dst = cv::Mat(output_height, output_width, dst_data_type,
 buffer_view.buffer());
 break;
 case Tensor::ElementType::kFloat32:
- dst = cv::Mat(output_height, output_width, mat_type_,
+ dst = cv::Mat(output_height, output_width, dst_data_type,
 buffer_view.buffer());
 break;
 case Tensor::ElementType::kUInt8:
- dst = cv::Mat(output_height, output_width, mat_type_,
+ dst = cv::Mat(output_height, output_width, dst_data_type,
 buffer_view.buffer());
 break;
 default:
@@ -137,7 +144,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
 auto transform,
 GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
 range_min, range_max));
- transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
+ transformed.convertTo(dst, dst_data_type, transform.scale,
+ transform.offset);
 return absl::OkStatus();
 }
@@ -148,7 +156,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
 RET_CHECK_EQ(output_shape.dims[0], 1)
 << "Handling batch dimension not equal to 1 is not implemented in this "
 "converter.";
- RET_CHECK_EQ(output_shape.dims[3], 3)
+ RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
 << "Wrong output channel: " << output_shape.dims[3];
 return absl::OkStatus();
 }
@@ -156,6 +164,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
 enum cv::BorderTypes border_mode_;
 Tensor::ElementType tensor_type_;
 int mat_type_;
+ int mat_gray_type_;
 };
 }  // namespace
diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc
index 3f4c05d4e..d27c595b5 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc
@@ -253,8 +253,11 @@ int GetNumOutputChannels(const mediapipe::Image& image) {
 }
 #endif  // MEDIAPIPE_METAL_ENABLED
 #endif  // !MEDIAPIPE_DISABLE_GPU
- // All of the processors except for Metal expect 3 channels.
- return 3;
+ // The output tensor has 1 channel for a 1-channel input image, and 3
+ // channels for a 3- or 4-channel input image.
+ // TODO: Add a unittest here to test the behavior on GPU, i.e.
+ // failure.
+ return image.channels() == 1 ? 1 : 3;
 }
 absl::StatusOr> GetInputImage(

From e2052a6a517fe1d8ce487f46a9856a225644d3f2 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Fri, 18 Nov 2022 19:41:05 -0800
Subject: [PATCH 017/346] Use shared_from_this in GlTextureBuffer::GetReadView, GetWriteView

This ensures that the callbacks in GlTextureView won't call an expired
object, even if user code holds a GlTextureView after releasing the buffer.

Note that GlTextureBuffer is not always held by a shared_ptr, but it always
is when GpuBuffer calls GetRead/WriteView on it. An alternative solution
would have been to have GpuBuffer pass its shared_ptr to the view method,
which could have been implemented with some compile-time logic to detect
whether the method expects such an argument. However, that doesn't seem
necessary.
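In isolation, the lifetime pattern this change relies on looks roughly like
the sketch below; Buffer stands in for GlTextureBuffer and the callback for
a GlTextureView's detach function (illustrative names only, not the actual
MediaPipe types):

#include <functional>
#include <memory>

// Deriving from enable_shared_from_this lets the buffer hand out callbacks
// that share ownership of it.
class Buffer : public std::enable_shared_from_this<Buffer> {
 public:
  // The returned callback captures a shared_ptr to the buffer, so invoking
  // it after the caller drops its own reference is still safe.
  std::function<void()> MakeDetachCallback() {
    return [self = shared_from_this()]() { self->DidRead(); };
  }

 private:
  void DidRead() {}
};

int main() {
  std::function<void()> detach;
  {
    auto buffer = std::make_shared<Buffer>();
    detach = buffer->MakeDetachCallback();
  }  // The caller's last shared_ptr goes away here...
  detach();  // ...but the callback still owns the Buffer, so this is safe.
}

This is also why the new DCHECK on weak_from_this().expired() matters:
shared_from_this() is only valid while the object is already owned by some
shared_ptr, which GpuBuffer guarantees at the point of the view calls.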
PiperOrigin-RevId: 489611843 --- mediapipe/gpu/gl_texture_buffer.cc | 23 +++++++++++++++++------ mediapipe/gpu/gl_texture_buffer.h | 3 ++- mediapipe/gpu/gpu_buffer_test.cc | 22 ++++++++++++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 09703d89d..7f77cd4b3 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -260,13 +260,18 @@ GlTextureView GlTextureBuffer::GetReadView(internal::types, auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); - GlTextureView::DetachFn detach = [this](GlTextureView& texture) { - // Inform the GlTextureBuffer that we have finished accessing its - // contents, and create a consumer sync point. - DidRead(texture.gl_context()->CreateSyncToken()); - }; + GlTextureView::DetachFn detach = + [texbuf = shared_from_this()](GlTextureView& texture) { + // Inform the GlTextureBuffer that we have finished accessing its + // contents, and create a consumer sync point. + texbuf->DidRead(texture.gl_context()->CreateSyncToken()); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), plane, std::move(detach), nullptr); } @@ -276,12 +281,18 @@ GlTextureView GlTextureBuffer::GetWriteView(internal::types, auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); Reuse(); // TODO: the producer wait should probably be part of Reuse in the // case when there are no consumers. GlTextureView::DoneWritingFn done_writing = - [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; + [texbuf = shared_from_this()](const GlTextureView& texture) { + texbuf->ViewDoneWriting(texture); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), plane, nullptr, std::move(done_writing)); } diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index c7643fd1b..f785571a1 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -35,7 +35,8 @@ class GlCalculatorHelperImpl; // Implements a GPU memory buffer as an OpenGL texture. For internal use. class GlTextureBuffer : public internal::GpuBufferStorageImpl< - GlTextureBuffer, internal::ViewProvider> { + GlTextureBuffer, internal::ViewProvider>, + public std::enable_shared_from_this { public: // This is called when the texture buffer is deleted. It is passed a sync // token created at that time on the GlContext. 
If the GlTextureBuffer has diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 796cb1d9d..145b71806 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gpu_buffer.h" +#include + #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -206,5 +208,25 @@ TEST_F(GpuBufferTest, Overwrite) { } } +TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + RunInGlContext([buffer = std::move(buffer)]() mutable { + // This is not a recommended pattern, but let's make sure that we don't + // crash if the buffer is released before the view. The view can hold + // callbacks into its underlying storage. + auto view = buffer.GetReadView(0); + buffer = nullptr; + }); + // We're really checking that we haven't crashed. + EXPECT_TRUE(true); +} + } // anonymous namespace } // namespace mediapipe From e853f04b79bb47e9542f54ba34065de3c5dcbd73 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 18 Nov 2022 19:53:21 -0800 Subject: [PATCH 018/346] Create AudioTaskRunner PiperOrigin-RevId: 489613573 --- .../tasks/audio/core/BaseAudioTaskApi.java | 1 + .../tasks/web/audio/audio_classifier/BUILD | 4 +- .../audio_classifier/audio_classifier.ts | 53 ++++++++--------- mediapipe/tasks/web/audio/core/BUILD | 14 ++++- .../web/audio/core/audio_task_options.d.ts | 21 ------- .../tasks/web/audio/core/audio_task_runner.ts | 58 +++++++++++++++++++ 6 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 mediapipe/tasks/web/audio/core/audio_task_runner.ts diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 8eaf0adcb..2782f8d36 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable { defaultSampleRate = sampleRate; } } + /** * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. 
diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 9e1fcbc51..498b17845 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -17,14 +17,14 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 5533b0eaa..0c54a4718 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ -export class AudioClassifier extends TaskRunner { +export class AudioClassifier extends AudioTaskRunner { private classificationResults: AudioClassifierResult[] = []; - private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); /** @@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the audio classifier. 
* @@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner { * * @param options The options for the audio classifier. */ - async setOptions(options: AudioClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: AudioClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } /** - * Sets the sample rate for all calls to `classify()` that omit an explicit - * sample rate. `48000` is used as a default if this method is not called. - * - * @param sampleRate A sample rate (e.g. `44100`). - */ - setDefaultSampleRate(sampleRate: number) { - this.defaultSampleRate = sampleRate; - } - - /** - * Performs audio classification on the provided audio data and waits + * Performs audio classification on the provided audio clip and waits * synchronously for the response. * - * @param audioData An array of raw audio capture data, like - * from a call to getChannelData on an AudioBuffer. + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. * @param sampleRate The sample rate in Hz of the provided audio data. If not * set, defaults to the sample rate set via `setDefaultSampleRate()` or * `48000` if no custom default was set. @@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner { */ classify(audioData: Float32Array, sampleRate?: number): AudioClassifierResult[] { - sampleRate = sampleRate ?? this.defaultSampleRate; + return this.processAudioClip(audioData, sampleRate); + } + /** Sends an audio package to the graph and returns the classifications. */ + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioClassifierResult[] { // Configures the number of samples in the WASM layer. We re-configure the // number of samples and the sample rate for every frame, but ignore other // side effects of this function (such as sending the input side packet and // the input stream header). this.configureAudio( /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); - - const timestamp = performance.now(); - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); - this.addAudioToStream(audioData, timestamp); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.addAudioToStream(audioData, timestampMs); this.classificationResults = []; this.finishProcessing(); diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index ed60f2435..91ebbf524 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -1,6 +1,6 @@ # This package contains options shared by all MediaPipe Audio Tasks for Web. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -11,3 +11,15 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core",
     ],
 )
+
+mediapipe_ts_library(
+    name = "audio_task_runner",
+    srcs = ["audio_task_runner.ts"],
+    deps = [
+        ":audio_task_options",
+        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
+        "//mediapipe/tasks/web/components/processors:base_options",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner",
+    ],
+)
diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts
index 58a6e55d8..e3068625d 100644
--- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts
+++ b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts
@@ -16,29 +16,8 @@
 
 import {BaseOptions} from '../../../../tasks/web/core/base_options';
 
-/**
- * MediaPipe audio task running mode. A MediaPipe audio task can be run with
- * two different modes:
- * - audio_clips: The mode for running a mediapipe audio task on independent
- *   audio clips.
- * - audio_stream: The mode for running a mediapipe audio task on an audio
- *   stream, such as from a microphone.
- *
- */
-export type RunningMode = 'audio_clips'|'audio_stream';
-
 /** The options for configuring a MediaPipe Audio Task. */
 export declare interface AudioTaskOptions {
   /** Options to configure the loading of the model assets. */
   baseOptions?: BaseOptions;
-
-  /**
-   * The running mode of the task. Default to the audio_clips mode.
-   * Audio tasks have two running modes:
-   * 1) The mode for running a mediapipe audio task on independent
-   *    audio clips.
-   * 2) The mode for running a mediapipe audio task on an audio
-   *    stream, such as from a microphone.
-   */
-  runningMode?: RunningMode;
 }
diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts
new file mode 100644
index 000000000..ceff3895b
--- /dev/null
+++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
+import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+
+import {AudioTaskOptions} from './audio_task_options';
+
+/** Base class for all MediaPipe Audio Tasks. */
+export abstract class AudioTaskRunner<T> extends TaskRunner {
+  protected abstract baseOptions?: BaseOptionsProto|undefined;
+  private defaultSampleRate = 48000;
+
+  /** Configures the shared options of an audio task. */
+  async setOptions(options: AudioTaskOptions): Promise<void> {
+    this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
+    if (options.baseOptions) {
+      this.baseOptions = await convertBaseOptionsToProto(
+          options.baseOptions, this.baseOptions);
+    }
+  }
+
+  /**
+   * Sets the sample rate for API calls that omit an explicit sample rate.
+   * `48000` is used as a default if this method is not called.
+   *
+   * @param sampleRate A sample rate (e.g. `44100`).
+   */
+  setDefaultSampleRate(sampleRate: number) {
+    this.defaultSampleRate = sampleRate;
+  }
+
+  /** Sends an audio packet to the graph and awaits results. */
+  protected abstract process(
+      audioData: Float32Array, sampleRate: number, timestampMs: number): T;
+
+  /** Sends a single audio clip to the graph and awaits results. */
+  protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
+    return this.process(
+        audioData, sampleRate ?? this.defaultSampleRate, performance.now());
+  }
+}
+
+
From bbcbd5fc6c8fcefaf45da9c126a6f7aa8b6386c2 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Sat, 19 Nov 2022 04:47:55 -0800
Subject: [PATCH 019/346] Audio Embedder for Web

PiperOrigin-RevId: 489669966
---
 mediapipe/tasks/web/BUILD                     |   1 +
 mediapipe/tasks/web/audio.ts                  |   4 +-
 mediapipe/tasks/web/audio/BUILD               |   1 +
 .../tasks/web/audio/audio_embedder/BUILD      |  43 ++++
 .../audio/audio_embedder/audio_embedder.ts    | 211 ++++++++++++++++++
 .../audio_embedder_options.d.ts               |  22 ++
 .../audio_embedder/audio_embedder_result.d.ts |  17 ++
 mediapipe/tasks/web/audio/index.ts            |   1 +
 8 files changed, 299 insertions(+), 1 deletion(-)
 create mode 100644 mediapipe/tasks/web/audio/audio_embedder/BUILD
 create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
 create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts
 create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts

diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD
index e9703e37a..af76a1fe8 100644
--- a/mediapipe/tasks/web/BUILD
+++ b/mediapipe/tasks/web/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
     srcs = ["audio.ts"],
     deps = [
         "//mediapipe/tasks/web/audio/audio_classifier",
+        "//mediapipe/tasks/web/audio/audio_embedder",
     ],
 )
 
diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts
index 764fd8393..056426f50 100644
--- a/mediapipe/tasks/web/audio.ts
+++ b/mediapipe/tasks/web/audio.ts
@@ -15,9 +15,11 @@
  */
 
 import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier';
+import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder';
 
 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
 const AudioClassifier = AudioClassifierImpl;
+const AudioEmbedder = AudioEmbedderImpl;
 
-export {AudioClassifier};
+export {AudioClassifier, AudioEmbedder};
diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD
index 4f6e48b28..acd7494d7 100644
--- a/mediapipe/tasks/web/audio/BUILD
+++ b/mediapipe/tasks/web/audio/BUILD
@@ -9,5 +9,6 @@ mediapipe_ts_library(
     srcs = ["index.ts"],
     deps = [
         "//mediapipe/tasks/web/audio/audio_classifier",
+        "//mediapipe/tasks/web/audio/audio_embedder",
     ],
 )
diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD
new file mode 100644
index 000000000..7d9a994a3
--- /dev/null
+++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD
@@ -0,0 +1,43 @@
+# This contains the MediaPipe Audio Embedder Task.
+#
+# This task takes audio input and performs embedding.
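The runner above is the template-method core that both the classifier (commit
018) and the embedder introduced here build on. A self-contained miniature of
that design, a sketch rather than the real MediaPipe classes (the names and
the string result type are stand-ins):

    // The base class owns the default sample rate and the timestamping; each
    // concrete task only implements one graph round-trip in process().
    abstract class MiniAudioTaskRunner<T> {
      private defaultSampleRate = 48000;

      setDefaultSampleRate(sampleRate: number) {
        this.defaultSampleRate = sampleRate;
      }

      /** Implemented per task: one synchronous graph invocation. */
      protected abstract process(
          audioData: Float32Array, sampleRate: number, timestampMs: number): T;

      /** Shared entry point: fills in the default rate and a timestamp. */
      protected processAudioClip(audioData: Float32Array, sampleRate?: number):
          T {
        return this.process(
            audioData, sampleRate ?? this.defaultSampleRate,
            performance.now());
      }
    }

    class MiniClassifier extends MiniAudioTaskRunner<string> {
      classify(audioData: Float32Array, sampleRate?: number): string {
        return this.processAudioClip(audioData, sampleRate);
      }

      protected override process(
          audioData: Float32Array, sampleRate: number,
          timestampMs: number): string {
        // A real task would send the samples to the graph here; the stand-in
        // only reports what it received.
        return `${audioData.length} samples at ${sampleRate} Hz ` +
            `(t=${timestampMs})`;
      }
    }

With this shape, `new MiniClassifier().classify(new Float32Array(480))`
reports 48000 Hz, while passing an explicit rate bypasses the default, which
is exactly the behavior the `classify()`/`embed()` doc comments describe.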
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_embedder", + srcs = ["audio_embedder.ts"], + deps = [ + ":audio_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_embedder_types", + srcs = [ + "audio_embedder_options.d.ts", + "audio_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/audio/core:audio_task_options", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts new file mode 100644 index 000000000..51cb819de --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -0,0 +1,211 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
+import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../../../../tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb';
+import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
+import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
+import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
+import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
+import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
+// Placeholder for internal dependency on trusted resource url
+
+import {AudioEmbedderOptions} from './audio_embedder_options';
+import {AudioEmbedderResult} from './audio_embedder_result';
+
+export * from './audio_embedder_options';
+export * from './audio_embedder_result';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
+// cannot be changed
+// TODO: Change this to `audio_in` to match the name in the CC
+// implementation
+const AUDIO_STREAM = 'input_audio';
+const SAMPLE_RATE_STREAM = 'sample_rate';
+const EMBEDDINGS_STREAM = 'embeddings_out';
+const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';
+const AUDIO_EMBEDDER_CALCULATOR =
+    'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph';
+
+/** Performs embedding extraction on audio. */
+export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
+  private embeddingResults: AudioEmbedderResult[] = [];
+  private readonly options = new AudioEmbedderGraphOptionsProto();
+
+  /**
+   * Initializes the Wasm runtime and creates a new audio embedder from the
+   * provided options.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param audioEmbedderOptions The options for the audio embedder. Note that
+   *     either a path to the TFLite model or the model itself needs to be
+   *     provided (via `baseOptions`).
+   */
+  static async createFromOptions(
+      wasmLoaderOptions: WasmLoaderOptions,
+      audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> {
+    // Create a file locator based on the loader options
+    const fileLocator: FileLocator = {
+      locateFile() {
+        // The only file we load is the Wasm binary
+        return wasmLoaderOptions.wasmBinaryPath.toString();
+      }
+    };
+
+    const embedder = await createMediaPipeLib(
+        AudioEmbedder, wasmLoaderOptions.wasmLoaderPath,
+        /* assetLoaderScript= */ undefined,
+        /* glCanvas= */ undefined, fileLocator);
+    await embedder.setOptions(audioEmbedderOptions);
+    return embedder;
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new audio embedder based on the
+   * provided model asset buffer.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the TFLite model.
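+   *
+   * Illustrative only, with a hypothetical model URL and a
+   * `wasmLoaderOptions` object shaped as in `createFromOptions()` above:
+   *
+   *   const data = await (await fetch('/audio_embedder.tflite'))
+   *       .arrayBuffer();
+   *   const embedder = await AudioEmbedder.createFromModelBuffer(
+   *       wasmLoaderOptions, new Uint8Array(data));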
+ */
+  static createFromModelBuffer(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> {
+    return AudioEmbedder.createFromOptions(
+        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new audio embedder based on the
+   * path to the model asset.
+   * @param wasmLoaderOptions A configuration object that provides the location
+   *     of the Wasm binary and its loader.
+   * @param modelAssetPath The path to the TFLite model.
+   */
+  static async createFromModelPath(
+      wasmLoaderOptions: WasmLoaderOptions,
+      modelAssetPath: string): Promise<AudioEmbedder> {
+    const response = await fetch(modelAssetPath.toString());
+    const graphData = await response.arrayBuffer();
+    return AudioEmbedder.createFromModelBuffer(
+        wasmLoaderOptions, new Uint8Array(graphData));
+  }
+
+  protected override get baseOptions(): BaseOptionsProto|undefined {
+    return this.options.getBaseOptions();
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto|undefined) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the audio embedder.
+   *
+   * Calling `setOptions()` with a subset of options only affects those
+   * options. You can reset an option back to its default value by explicitly
+   * setting it to `undefined`.
+   *
+   * @param options The options for the audio embedder.
+   */
+  override async setOptions(options: AudioEmbedderOptions): Promise<void> {
+    await super.setOptions(options);
+    this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
+        options, this.options.getEmbedderOptions()));
+    this.refreshGraph();
+  }
+
+  /**
+   * Performs embedding extraction on the provided audio clip and waits
+   * synchronously for the response.
+   *
+   * @param audioData An array of raw audio capture data, like from a call to
+   *     `getChannelData()` on an AudioBuffer.
+   * @param sampleRate The sample rate in Hz of the provided audio data. If not
+   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
+   *     `48000` if no custom default was set.
+   * @return The embedding results of the audio
+   */
+  embed(audioData: Float32Array, sampleRate?: number): AudioEmbedderResult[] {
+    return this.processAudioClip(audioData, sampleRate);
+  }
+
+  protected override process(
+      audioData: Float32Array, sampleRate: number,
+      timestampMs: number): AudioEmbedderResult[] {
+    // Configures the number of samples in the WASM layer. We re-configure the
+    // number of samples and the sample rate for every frame, but ignore other
+    // side effects of this function (such as sending the input side packet and
+    // the input stream header).
+    this.configureAudio(
+        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
+    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.addAudioToStream(audioData, timestampMs);
+
+    this.embeddingResults = [];
+    this.finishProcessing();
+    return this.embeddingResults;
+  }
+
+  /** Updates the MediaPipe graph configuration. */
+  private refreshGraph(): void {
+    const graphConfig = new CalculatorGraphConfig();
+    graphConfig.addInputStream(AUDIO_STREAM);
+    graphConfig.addInputStream(SAMPLE_RATE_STREAM);
+    graphConfig.addOutputStream(EMBEDDINGS_STREAM);
+    graphConfig.addOutputStream(TIMESTAMPED_EMBEDDINGS_STREAM);
+
+    const calculatorOptions = new CalculatorOptions();
+    calculatorOptions.setExtension(
+        AudioEmbedderGraphOptionsProto.ext, this.options);
+
+    const embedderNode = new CalculatorGraphConfig.Node();
+    embedderNode.setCalculator(AUDIO_EMBEDDER_CALCULATOR);
+    embedderNode.addInputStream('AUDIO:' + AUDIO_STREAM);
+    embedderNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
+    embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM);
+    embedderNode.addOutputStream(
+        'TIMESTAMPED_EMBEDDINGS:' + TIMESTAMPED_EMBEDDINGS_STREAM);
+    embedderNode.setOptions(calculatorOptions);
+
+    graphConfig.addNode(embedderNode);
+
+    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+      const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
+      this.embeddingResults.push(
+          convertFromEmbeddingResultProto(embeddingResult));
+    });
+
+    this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => {
+      for (const binaryProto of data) {
+        const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
+        this.embeddingResults.push(
+            convertFromEmbeddingResultProto(embeddingResult));
+      }
+    });
+
+    const binaryGraph = graphConfig.serializeBinary();
+    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
+  }
+}
+
+
diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts
new file mode 100644
index 000000000..98f412d0f
--- /dev/null
+++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options';
+import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
+
+/** Options to configure the MediaPipe Audio Embedder Task */
+export declare interface AudioEmbedderOptions extends EmbedderOptions,
+                                                      AudioTaskOptions {}
diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts
new file mode 100644
index 000000000..13abc28d9
--- /dev/null
+++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts
@@ -0,0 +1,17 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+export {Embedding, EmbeddingResult as AudioEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result';
diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts
index a5083b326..17a908f30 100644
--- a/mediapipe/tasks/web/audio/index.ts
+++ b/mediapipe/tasks/web/audio/index.ts
@@ -15,3 +15,4 @@
  */
 
 export * from '../../../tasks/web/audio/audio_classifier/audio_classifier';
+export * from '../../../tasks/web/audio/audio_embedder/audio_embedder';

From 977ee4272e90272fef0ab140036816e83e05c615 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Sat, 19 Nov 2022 10:51:20 -0800
Subject: [PATCH 020/346] Add public visibility to the model maker public API.

PiperOrigin-RevId: 489701768
---
 mediapipe/model_maker/python/text/text_classifier/BUILD    | 7 +++++++
 .../model_maker/python/vision/gesture_recognizer/BUILD     | 7 +++++++
 mediapipe/model_maker/python/vision/image_classifier/BUILD | 7 +++++++
 3 files changed, 21 insertions(+)

diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD
index 0c35e7966..7bb41351e 100644
--- a/mediapipe/model_maker/python/text/text_classifier/BUILD
+++ b/mediapipe/model_maker/python/text/text_classifier/BUILD
@@ -21,9 +21,16 @@ package(
 
 licenses(["notice"])
 
+######################################################################
+# Public target of the MediaPipe Model Maker TextClassifier APIs.
+
+# Please see https://developers.google.com/mediapipe/solutions/text/text_classifier/customize for
+# more information about the MediaPipe Model Maker TextClassifier APIs.
+######################################################################
 py_library(
     name = "text_classifier_import",
    srcs = ["__init__.py"],
+    visibility = ["//visibility:public"],
     deps = [
         ":dataset",
         ":model_options",
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
index b7d334d9c..b9425a181 100644
--- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
+++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
@@ -103,9 +103,16 @@ py_library(
     ],
 )
 
+######################################################################
+# Public target of the MediaPipe Model Maker GestureRecognizer APIs.
+
+# Please see https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer/customize
+# for more information about the MediaPipe Model Maker GestureRecognizer APIs.
+###################################################################### py_library( name = "gesture_recognizer_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":gesture_recognizer", diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index c581d9fbc..29ae189e9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -21,9 +21,16 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) +###################################################################### +# Public target of the MediaPipe Model Maker ImageClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize for +# more information about the MediaPipe Model Maker ImageClassifier APIs. +###################################################################### py_library( name = "image_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":hyperparameters", From a33cb1e05e602cb06b6e6ecdc3a12dad82f5f4e4 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sat, 19 Nov 2022 21:03:29 -0800 Subject: [PATCH 021/346] Check that Java buffer supports direct access before using it If the buffer is not created with allocateDirect, JNI APIs will return a data pointer of nullptr and a capacity of -1. This can cause a crash when we access it. Also clean up the code to raise exceptions instead of just logging errors and returning nullptr. PiperOrigin-RevId: 489751312 --- .../framework/jni/packet_creator_jni.cc | 171 +++++++++++------- .../framework/jni/packet_getter_jni.cc | 42 +++-- 2 files changed, 133 insertions(+), 80 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 250d7c938..2d5447401 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/formats/image.h" @@ -107,17 +109,18 @@ absl::StatusOr CreateGpuBuffer( // Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java // ByteBuffer. 
-std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
-    JNIEnv* env, jobject byte_buffer, jint width, jint height,
-    mediapipe::ImageFormat::Format format) {
+absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
+CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
+                               jint height,
+                               mediapipe::ImageFormat::Format format) {
   switch (format) {
     case mediapipe::ImageFormat::SRGBA:
     case mediapipe::ImageFormat::SRGB:
     case mediapipe::ImageFormat::GRAY8:
       break;
     default:
-      LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8.";
-      return nullptr;
+      return absl::InvalidArgumentError(
+          "Format must be either SRGBA, SRGB, or GRAY8.");
   }
 
   auto image_frame = std::make_unique<mediapipe::ImageFrame>(
@@ -125,25 +128,30 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
       mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
 
   const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
+  const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
+  if (buffer_data == nullptr || buffer_size < 0) {
+    return absl::InvalidArgumentError(
+        "Cannot get direct access to the input buffer. It should be created "
+        "using allocateDirect.");
+  }
+
   const int num_channels = image_frame->NumberOfChannels();
   const int expected_buffer_size =
       num_channels == 1 ? width * height : image_frame->PixelDataSize();
-  if (buffer_size != expected_buffer_size) {
-    if (num_channels != 1)
-      LOG(ERROR) << "The input image buffer should have 4 bytes alignment.";
-    LOG(ERROR) << "Please check the input buffer size.";
-    LOG(ERROR) << "Buffer size: " << buffer_size
-               << ", Buffer size needed: " << expected_buffer_size
-               << ", Image width: " << width;
-    return nullptr;
-  }
+  RET_CHECK_EQ(buffer_size, expected_buffer_size)
+      << (num_channels != 1
+              ? "The input image buffer should have 4 bytes alignment. "
+              : "")
+      << "Please check the input buffer size."
+      << " Buffer size: " << buffer_size
+      << ", Buffer size needed: " << expected_buffer_size
+      << ", Image width: " << width;
 
   // Copy buffer data to image frame's pixel_data_.
   if (num_channels == 1) {
     const int width_step = image_frame->WidthStep();
-    const char* src_row =
-        reinterpret_cast<const char*>(env->GetDirectBufferAddress(byte_buffer));
+    const char* src_row = reinterpret_cast<const char*>(buffer_data);
     char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData());
     for (int i = height; i > 0; --i) {
       std::memcpy(dst_row, src_row, width);
@@ -152,7 +160,6 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
     }
   } else {
     // 3 and 4 channels.
- const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); std::memcpy(image_frame->MutablePixelData(), buffer_data, image_frame->PixelDataSize()); } @@ -176,77 +183,100 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } +absl::StatusOr> CreateRgbImageFromRgba( + JNIEnv* env, jobject byte_buffer, jint width, jint height) { + const uint8_t* rgba_data = + static_cast(env->GetDirectBufferAddress(byte_buffer)); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); + } + + const int expected_buffer_size = width * height * 4; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << expected_buffer_size + << ", Image width: " << width; + + auto image_frame = absl::make_unique( + mediapipe::ImageFormat::SRGB, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, + image_frame->MutablePixelData(), + image_frame->WidthStep()); + return image_frame; +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const uint8_t* rgba_data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGB, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != width * height * 4) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << width * height * 4 - << ", Image width: " << width; - return 0L; - } - mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, - image_frame->MutablePixelData(), - image_frame->WidthStep()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - 
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const void* data = env->GetDirectBufferAddress(byte_buffer); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + // TODO: merge this case with CreateImageFrameFromByteBuffer. + auto image_frame_or = + [&]() -> absl::StatusOr> { + const void* data = env->GetDirectBufferAddress(byte_buffer); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "input buffer does not support direct access"); + } + + auto image_frame = absl::make_unique( + mediapipe::ImageFormat::VEC32F1, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize()) + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << image_frame->PixelDataSize() + << ", Image width: " << width; + std::memcpy(image_frame->MutablePixelData(), data, + image_frame->PixelDataSize()); + return image_frame; + }(); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -291,6 +321,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)( jint num_samples) { const uint8_t* audio_sample = reinterpret_cast(env->GetDirectBufferAddress(data)); + if (!audio_sample) { + ThrowIfError(env, absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. 
It " + "should be created using allocateDirect.")); + return 0L; + } mediapipe::Packet packet = createAudioPacket(audio_sample, num_samples, num_channels); return CreatePacketWithContext(context, packet); @@ -360,8 +396,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data) { if (env->GetArrayLength(data) != rows * cols) { - LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " - << rows * cols; + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Please check the matrix data size, has to be rows * cols = ", + rows * cols))); return 0L; } std::unique_ptr matrix(new mediapipe::Matrix(rows, cols)); @@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( format = mediapipe::ImageFormat::GRAY8; break; default: - LOG(ERROR) << "Channels must be either 1, 3, or 4."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Channels must be either 1, 3, or 4, but are ", + num_channels))); return 0L; } - auto image_frame = + auto image_frame_or = CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = - mediapipe::MakePacket(std::move(image_frame)); + mediapipe::MakePacket(*std::move(image_frame_or)); return CreatePacketWithContext(context, packet); } @@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( jbyte* data_ref = env->GetByteArrayElements(data, nullptr); auto options = absl::make_unique(); if (!options->ParseFromArray(data_ref, count)) { - LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Parsing binary-encoded CalculatorOptions failed."))); return 0L; } mediapipe::Packet packet = mediapipe::Adopt(options.release()); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index c215dd929..737f6db72 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" @@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( : GetFromNativeHandle(packet); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } // Assume byte buffer stores pixel data contiguously. 
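  // Illustration of the size check below (values assumed for illustration,
  // not taken from this change): an SRGB frame of 640x480 has
  // ByteDepth() == 1 and NumberOfChannels() == 3, so the Java caller must
  // supply ByteBuffer.allocateDirect(640 * 480 * 1 * 3) == 921600 bytes; a
  // non-direct buffer from ByteBuffer.allocate() now raises an exception
  // here instead of reading through a null data pointer.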
const int expected_buffer_size = image.Width() * image.Height() * image.ByteDepth() * image.NumberOfChannels(); if (buffer_size != expected_buffer_size) { - LOG(ERROR) << "Expected buffer size " << expected_buffer_size - << " got: " << buffer_size << ", width " << image.Width() - << ", height " << image.Height() << ", channels " - << image.NumberOfChannels(); + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Expected buffer size ", expected_buffer_size, + " got: ", buffer_size, ", width ", image.Width(), ", height ", + image.Height(), ", channels ", image.NumberOfChannels()))); return false; } switch (image.ByteDepth()) { case 1: { - uint8* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); + uint8* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 2: { - uint16* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); + uint16* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 4: { - float* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); + float* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } @@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( uint8_t* rgba_data = static_cast(env->GetDirectBufferAddress(byte_buffer)); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } if (buffer_size != image.Width() * image.Height() * 4) { - LOG(ERROR) << "Buffer size has to be width*height*4\n" - << "Image width: " << image.Width() - << ", Image height: " << image.Height() - << ", Buffer size: " << buffer_size << ", Buffer size needed: " - << image.Width() * image.Height() * 4; + ThrowIfError(env, + absl::InvalidArgumentError(absl::StrCat( + "Buffer size has to be width*height*4\n" + "Image width: ", + image.Width(), ", Image height: ", image.Height(), + ", Buffer size: ", buffer_size, ", Buffer size needed: ", + image.Width() * image.Height() * 4))); return false; } mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(), From bdf4078e89cb11e01da0c5eda6322a22ad74e127 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sat, 19 Nov 2022 21:12:23 -0800 Subject: [PATCH 022/346] Internal change PiperOrigin-RevId: 489752009 --- mediapipe/model_maker/python/core/utils/BUILD | 1 + .../python/core/utils/model_util_test.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 12fef631f..492bba0a9 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -45,6 +45,7 @@ py_test( name = "model_util_test", srcs = ["model_util_test.py"], deps = [ + ":file_util", ":model_util", ":quantization", ":test_util", diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 05c6ffe3f..f0020db25 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -14,10 +14,12 @@ import os from typing import Optional +from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf +from mediapipe.model_maker.python.core.utils import file_util 
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util @@ -25,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_keras_model(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_keras_model(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') model.save(saved_model_path) + # model_util.load_keras_model takes in a relative path to files within the + # model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = saved_model_path loaded_model = model_util.load_keras_model(saved_model_path) input_tensors = test_util.create_random_sample(size=[1, input_dim]) @@ -37,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) - def test_load_tflite_model_buffer(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_tflite_model_buffer(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_model = model_util.convert_to_tflite(model) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) - + # model_util.load_tflite_model_buffer takes in a relative path to files + # within the model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = tflite_file tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) test_util.test_tflite( keras_model=model, From a367753eda595f01a60e4ccb12845f2675cb37c5 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Sun, 20 Nov 2022 10:39:59 -0800 Subject: [PATCH 023/346] Internal change PiperOrigin-RevId: 489824381 --- .../vision/gesture_recognizer/gesture_recognizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 39272cbbc..9cee88362 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,7 +14,6 @@ import io import os -import random import tempfile from unittest import mock as unittest_mock import zipfile @@ -27,6 +26,7 @@ from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +tf.keras.backend.experimental.enable_tf_random_generator() class GestureRecognizerTest(tf.test.TestCase): @@ -42,7 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - random.seed(1234) + tf.keras.utils.set_random_seed(87654321) all_data = self._load_data() # Splits data, 90% data for training, 10% for validation self._train_data, self._validation_data = all_data.split(0.9) From 6cf464636b00fb5039bf705319ffe09408d207b3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: 
Sun, 20 Nov 2022 14:24:21 -0800 Subject: [PATCH 024/346] Internal change PiperOrigin-RevId: 489842199 --- mediapipe/tasks/BUILD | 7 ++ .../tasks/cc/audio/audio_classifier/BUILD | 53 ++++++----- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 55 ++++++------ mediapipe/tasks/cc/audio/core/BUILD | 1 + .../tasks/cc/components/containers/BUILD | 2 +- .../tasks/cc/components/processors/BUILD | 2 + mediapipe/tasks/cc/core/BUILD | 4 +- mediapipe/tasks/cc/text/text_classifier/BUILD | 51 ++++++----- mediapipe/tasks/cc/text/text_embedder/BUILD | 3 + mediapipe/tasks/cc/vision/core/BUILD | 2 + .../tasks/cc/vision/gesture_recognizer/BUILD | 90 ++++++++++--------- .../tasks/cc/vision/hand_landmarker/BUILD | 72 ++++++++------- .../tasks/cc/vision/image_classifier/BUILD | 49 +++++----- .../tasks/cc/vision/image_embedder/BUILD | 49 +++++----- .../tasks/cc/vision/image_segmenter/BUILD | 6 +- .../tasks/cc/vision/object_detector/BUILD | 65 +++++++------- 16 files changed, 278 insertions(+), 233 deletions(-) diff --git a/mediapipe/tasks/BUILD b/mediapipe/tasks/BUILD index 242a88cfc..98ddd5777 100644 --- a/mediapipe/tasks/BUILD +++ b/mediapipe/tasks/BUILD @@ -21,3 +21,10 @@ package_group( "//mediapipe/tasks/...", ], ) + +package_group( + name = "users", + includes = [ + ":internal", + ], +) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 1955adfe7..a817bcc3b 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -16,6 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Classifier +# https://developers.google.com/mediapipe/solutions/audio/audio_classifier +cc_library( + name = "audio_classifier", + srcs = ["audio_classifier.cc"], + hdrs = ["audio_classifier.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_classifier_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_classifier_graph", srcs = ["audio_classifier_graph.cc"], @@ -52,28 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_classifier", - srcs = ["audio_classifier.cc"], - hdrs = ["audio_classifier.h"], - deps = [ - ":audio_classifier_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - 
"//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index b982ef39a..adba28e6a 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -16,6 +16,36 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Embedder +# https://developers.google.com/mediapipe/solutions/audio/audio_embedder +cc_library( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_embedder_graph", srcs = ["audio_embedder_graph.cc"], @@ -51,29 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_embedder", - srcs = ["audio_embedder.cc"], - hdrs = ["audio_embedder.h"], - deps = [ - ":audio_embedder_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:embedding_result", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedder_options", - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:cosine_similarity", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build 
diff --git a/mediapipe/tasks/cc/audio/core/BUILD b/mediapipe/tasks/cc/audio/core/BUILD index 93362fd3d..016faa10f 100644 --- a/mediapipe/tasks/cc/audio/core/BUILD +++ b/mediapipe/tasks/cc/audio/core/BUILD @@ -19,6 +19,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 2f5f8be5b..dec977fb8 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 7845a3dae..32a628db7 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -20,6 +20,7 @@ cc_library( name = "classifier_options", srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], ) @@ -67,6 +68,7 @@ cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], hdrs = ["embedder_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"], ) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f14457073..202f3ea3c 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -22,9 +22,7 @@ cc_library( name = "base_options", srcs = ["base_options.cc"], hdrs = ["base_options.h"], - visibility = [ - "//mediapipe/tasks:internal", - ], + visibility = ["//visibility:public"], deps = [ ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 52b0c0e4b..01adc9fc3 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -16,6 +16,33 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Classifier +# https://developers.google.com/mediapipe/solutions/text/text_classifier +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], + deps = [ + ":text_classifier_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:category", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + 
"@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "text_classifier_graph", srcs = ["text_classifier_graph.cc"], @@ -41,30 +68,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "text_classifier", - srcs = ["text_classifier.cc"], - hdrs = ["text_classifier.h"], - deps = [ - ":text_classifier_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/tasks/cc/components/containers:category", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - cc_test( name = "text_classifier_test", srcs = ["text_classifier_test.cc"], diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index e2e16c9c1..27c9cb730 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -16,10 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Embedder +# https://developers.google.com/mediapipe/solutions/text/text_embedder cc_library( name = "text_embedder", srcs = ["text_embedder.cc"], hdrs = ["text_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_graph", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index e8e197a1d..1f5ab5faf 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -19,11 +19,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( name = "image_processing_options", hdrs = ["image_processing_options.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/cc/components/containers:rect", ], diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 75289b1e8..7b144e7aa 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -18,6 +18,52 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Gesture Recognizer +# https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":gesture_recognizer_graph", + ":gesture_recognizer_result", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + 
"//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + cc_library( name = "handedness_util", srcs = ["handedness_util.cc"], @@ -127,51 +173,9 @@ cc_library( cc_library( name = "gesture_recognizer_result", hdrs = ["gesture_recognizer_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) - -cc_library( - name = "gesture_recognizer", - srcs = ["gesture_recognizer.cc"], - hdrs = ["gesture_recognizer.h"], - deps = [ - ":gesture_recognizer_graph", - ":gesture_recognizer_result", - ":hand_gesture_recognizer_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", - 
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5c5073fc2..3b869eab4 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -18,6 +18,43 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Hand Landmarker +# https://developers.google.com/mediapipe/solutions/vision/hand_landmarker +cc_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.cc"], + hdrs = ["hand_landmarker.h"], + visibility = ["//visibility:public"], + deps = [ + ":hand_landmarker_graph", + ":hand_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], @@ -113,44 +150,11 @@ cc_library( cc_library( name = "hand_landmarker_result", hdrs = ["hand_landmarker_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) -cc_library( - name = "hand_landmarker", - srcs = ["hand_landmarker.cc"], - hdrs = ["hand_landmarker.h"], - deps = [ - ":hand_landmarker_graph", - ":hand_landmarker_result", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - 
"//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], -) - # TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index b59d8d682..2b93aa262 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_classifier_graph", - srcs = ["image_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Classifier +# https://developers.google.com/mediapipe/solutions/vision/image_classifier cc_library( name = "image_classifier", srcs = ["image_classifier.cc"], hdrs = ["image_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_graph", "//mediapipe/framework:packet", @@ -69,4 +49,27 @@ cc_library( ], ) +cc_library( + name = "image_classifier_graph", + srcs = ["image_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + 
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index ea7f40261..8fdb97ccd 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_embedder_graph", - srcs = ["image_embedder_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Embedder +# https://developers.google.com/mediapipe/solutions/vision/image_embedder cc_library( name = "image_embedder", srcs = ["image_embedder.cc"], hdrs = ["image_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_graph", "//mediapipe/framework/api2:builder", @@ -67,4 +47,27 @@ cc_library( ], ) +cc_library( + name = "image_embedder_graph", + srcs = ["image_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 7206a45ea..595eef568 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,13 +16,13 @@ package(default_visibility = 
["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Image Segmenter +# https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], - visibility = [ - "//mediapipe/tasks:internal", - ], + visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 8220d8b7f..b8002fa96 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -16,6 +16,41 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Object Detector +# https://developers.google.com/mediapipe/solutions/vision/object_detector +cc_library( + name = "object_detector", + srcs = ["object_detector.cc"], + hdrs = ["object_detector.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":object_detector_graph", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "object_detector_graph", srcs = ["object_detector_graph.cc"], @@ -56,34 +91,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "object_detector", - srcs = ["object_detector.cc"], - hdrs = ["object_detector.h"], - deps = [ - ":object_detector_graph", - "//mediapipe/calculators/core:concatenate_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - 
"@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: This test fails in OSS From 3ac7f6a216c12d617edd6549ace59f4f76e085c7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sun, 20 Nov 2022 19:30:05 -0800 Subject: [PATCH 025/346] Simplify image creation in PacketCreator Use more existing functions, remove redundant code, remove direct use of RuntimeException. PiperOrigin-RevId: 489868983 --- .../mediapipe/framework/PacketCreator.java | 53 +++++---- .../framework/jni/packet_creator_jni.cc | 104 +++++------------- .../framework/jni/packet_creator_jni.h | 2 +- 3 files changed, 64 insertions(+), 95 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index d93eea7b5..04265cab5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -55,7 +55,11 @@ public class PacketCreator { public Packet createRgbImage(ByteBuffer buffer, int width, int height) { int widthStep = (((width * 3) + 3) / 4) * 4; if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + widthStep * height + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -123,7 +127,11 @@ public class PacketCreator { */ public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) { if (width * height * 4 != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + width * height * 4); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -136,7 +144,7 @@ public class PacketCreator { */ public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) { if (width * height != buffer.capacity()) { - throw new RuntimeException( + throw new IllegalArgumentException( "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); } return Packet.create( @@ -150,7 +158,11 @@ public class PacketCreator { */ public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -163,7 +175,11 @@ public class PacketCreator { */ public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -354,25 +370,24 @@ public class PacketCreator { *
For 3 and 4 channel images, the pixel rows should have 4-byte alignment. */ public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) { + int widthStep; if (numChannels == 4) { - if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); - } + widthStep = width * 4; } else if (numChannels == 3) { - int widthStep = (((width * 3) + 3) / 4) * 4; - if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); - } + widthStep = (((width * 3) + 3) / 4) * 4; } else if (numChannels == 1) { - if (width * height != buffer.capacity()) { - throw new RuntimeException( - "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); - } + widthStep = width; } else { - throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels); + throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels); + } + int expectedSize = widthStep * height; + if (buffer.capacity() != expectedSize) { + throw new IllegalArgumentException( + "The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity()); } return Packet.create( - nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels)); + nativeCreateCpuImage( + mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels)); } /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */ @@ -430,7 +445,7 @@ public class PacketCreator { long context, int name, int width, int height, TextureReleaseCallback releaseCallback); private native long nativeCreateCpuImage( - long context, ByteBuffer buffer, int width, int height, int numChannels); + long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels); private native long nativeCreateInt32Array(long context, int[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 2d5447401..46ea1ce41 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -111,22 +111,8 @@ absl::StatusOr CreateGpuBuffer( // ByteBuffer. absl::StatusOr> CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, - jint height, + jint height, jint width_step, mediapipe::ImageFormat::Format format) { - switch (format) { - case mediapipe::ImageFormat::SRGBA: - case mediapipe::ImageFormat::SRGB: - case mediapipe::ImageFormat::GRAY8: - break; - default: - return absl::InvalidArgumentError( - "Format must be either SRGBA, SRGB, or GRAY8."); - } - - auto image_frame = std::make_unique( - format, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); if (buffer_data == nullptr || buffer_size < 0) { @@ -135,34 +121,19 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, "using allocateDirect."); } - const int num_channels = image_frame->NumberOfChannels(); - const int expected_buffer_size = - num_channels == 1 ? 
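          // (Worked numbers, a sketch rather than patch content, for the new
          // uniform check expected_buffer_size = height * width_step: with
          // width = 5 and height = 2 the callers below pass
          //   GRAY8: width_step = 5           -> expects 10 bytes
          //   SRGB:  width_step = 16 (padded) -> expects 32 bytes
          //   SRGBA: width_step = 20          -> expects 40 bytes.)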
width * height : image_frame->PixelDataSize(); - + const int expected_buffer_size = height * width_step; RET_CHECK_EQ(buffer_size, expected_buffer_size) - << (num_channels != 1 - ? "The input image buffer should have 4 bytes alignment. " - : "") - << "Please check the input buffer size." - << " Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; - // Copy buffer data to image frame's pixel_data_. - if (num_channels == 1) { - const int width_step = image_frame->WidthStep(); - const char* src_row = reinterpret_cast(buffer_data); - char* dst_row = reinterpret_cast(image_frame->MutablePixelData()); - for (int i = height; i > 0; --i) { - std::memcpy(dst_row, src_row, width); - src_row += width; - dst_row += width_step; - } - } else { - // 3 and 4 channels. - std::memcpy(image_frame->MutablePixelData(), buffer_data, - image_frame->PixelDataSize()); - } + auto image_frame = std::make_unique(); + // TODO: we could retain the buffer with a special deleter and use + // the data directly without a copy. May need a new Java API since existing + // code might expect to be able to overwrite the buffer after creating an + // ImageFrame from it. + image_frame->CopyPixelData( + format, width, height, width_step, static_cast(buffer_data), + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); return image_frame; } @@ -183,8 +154,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); + // We require 4-byte alignment. See Java method. + constexpr int kAlignment = 4; + int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1; + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, + width_step, mediapipe::ImageFormat::SRGB); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); @@ -204,10 +179,8 @@ absl::StatusOr> CreateRgbImageFromRgba( const int expected_buffer_size = width * height * 4; RET_CHECK_EQ(buffer_size, expected_buffer_size) - << "Please check the input buffer size." 
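  // (A sketch of the row-alignment math used above, not itself part of the
  // patch: the Java expression (((width * 3) + 3) / 4) * 4 and the JNI
  // expression ((width * 3 - 1) | (kAlignment - 1)) + 1 with kAlignment == 4
  // are equivalent; both round width * 3 up to the next multiple of 4,
  // e.g. width 5 -> 16 and width 8 -> 24.)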
- << " Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; auto image_frame = absl::make_unique( mediapipe::ImageFormat::SRGB, width, height, @@ -232,7 +205,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); + env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); @@ -242,28 +215,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - // TODO: merge this case with CreateImageFrameFromByteBuffer. auto image_frame_or = - [&]() -> absl::StatusOr> { - const void* data = env->GetDirectBufferAddress(byte_buffer); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (data == nullptr || buffer_size < 0) { - return absl::InvalidArgumentError( - "input buffer does not support direct access"); - } - - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize()) - << "Please check the input buffer size." - << " Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - return image_frame; - }(); + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::VEC32F1); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); @@ -272,10 +226,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::SRGBA); if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -417,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels) { + jint height, jint width_step, jint num_channels) { mediapipe::ImageFormat::Format format; switch (num_channels) { case 4: @@ -436,8 +390,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( return 0L; } - auto image_frame_or = - CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); + auto 
image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width_step, format); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index d6f44b0a3..b3b1043fb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels); + jint height, jint width_step, jint num_channels); JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, From 13c6b9a8c6ce6fc9d0e34316821d497bb7f4f9f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sun, 20 Nov 2022 22:18:49 -0800 Subject: [PATCH 026/346] Allow kernel cache path to be specified without trailing path delimiter PiperOrigin-RevId: 489891079 --- .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index ad5df849f..c2c723402 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -241,9 +241,9 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( gpu_delegate_options.has_model_token(); if (use_kernel_caching_) { - cached_kernel_filename_ = gpu_delegate_options.cached_kernel_path() + - mediapipe::File::Basename(options.model_path()) + - ".ker"; + cached_kernel_filename_ = mediapipe::file::JoinPath( + gpu_delegate_options.cached_kernel_path(), + mediapipe::File::Basename(options.model_path()) + ".ker"); } if (use_serialized_model_) { serialized_model_path_ = From 7acbf557a1294e3809e8671ac769c855dd3336c4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 01:55:49 -0800 Subject: [PATCH 027/346] Cleanup after migration to new classification output format. 
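
(A sketch of the new output shape, added for orientation and not part of the
original commit message: each per-head Classifications message now wraps a
standard mediapipe.ClassificationList instead of the removed Category and
ClassificationEntry messages, so consuming C++ reads roughly as follows,
assuming `result` is a tasks ClassificationResult proto:

    for (const auto& head : result.classifications()) {
      for (const auto& c : head.classification_list().classification()) {
        // Classification carries index/score/label/display_name; the old
        // Category.category_name corresponds to Classification.label.
        LOG(INFO) << head.head_name() << ": " << c.label() << " = " << c.score();
      }
    }

Field names follow the protos as modified below.)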
PiperOrigin-RevId: 489921603 --- .../tasks/cc/components/calculators/BUILD | 1 - .../classification_aggregation_calculator.cc | 68 +--- .../cc/components/containers/proto/BUILD | 6 - .../containers/proto/category.proto | 41 --- .../containers/proto/classifications.proto | 17 +- .../classification_postprocessing_graph.cc | 9 - .../classification_postprocessing_graph.h | 3 - ...lassification_postprocessing_graph_test.cc | 322 ------------------ .../text_classifier/text_classifier_graph.cc | 27 +- .../image_classifier_graph.cc | 9 - .../com/google/mediapipe/tasks/text/BUILD | 1 - .../com/google/mediapipe/tasks/vision/BUILD | 1 - .../tasks/python/components/containers/BUILD | 2 +- .../python/components/containers/category.py | 16 +- .../containers/classification_result.py | 15 +- 15 files changed, 23 insertions(+), 515 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/containers/proto/category.proto diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 1f726a018..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 1a83fdad2..ad2c668c3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,14 +25,12 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into either a ClassificationResult object // representing the classification results aggregated by classifier head, or @@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications; // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example without timestamp aggregation: // node { @@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node { ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); std::vector ConvertToTimestampedClassificationResults( CalculatorContext* cc); - // TODO: deprecate this function once migration is over. 
- ClassificationResult LegacyConvertToClassificationResult( - CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } - // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if - // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is - // not connected. All dependent tasks must be updated to use these outputs - // first. + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected()); + } else { + RET_CHECK(kClassificationsOut(cc).IsConnected()); + } return absl::OkStatus(); } @@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process( if (kTimestampsIn(cc).IsEmpty()) { return absl::OkStatus(); } - classification_result = LegacyConvertToClassificationResult(cc); kTimestampedClassificationsOut(cc).Send( ConvertToTimestampedClassificationResults(cc)); } else { - classification_result = LegacyConvertToClassificationResult(cc); kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } kClassificationResultOut(cc).Send(classification_result); @@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( return results; } -ClassificationResult -ClassificationAggregationCalculator::LegacyConvertToClassificationResult( - CalculatorContext* cc) { - ClassificationResult result; - Timestamp first_timestamp(0); - std::vector timestamps; - if (time_aggregation_enabled_) { - timestamps = kTimestampsIn(cc).Get(); - first_timestamp = timestamps[0]; - } else { - timestamps = {cc->InputTimestamp()}; - } - for (Timestamp timestamp : timestamps) { - int count = cached_classifications_[timestamp.Value()].size(); - for (int i = 0; i < count; ++i) { - Classifications* c; - if (result.classifications_size() <= i) { - c = result.add_classifications(); - if (!head_names_.empty()) { - c->set_head_index(i); - c->set_head_name(head_names_[i]); - } - } else { - c = result.mutable_classifications(i); - } - auto* entry = c->add_entries(); - for (const auto& elem : - cached_classifications_[timestamp.Value()][i].classification()) { - auto* category = entry->add_categories(); - if (elem.has_index()) { - category->set_index(elem.index()); - } - if (elem.has_score()) { - category->set_score(elem.score()); - } - if (elem.has_label()) { - category->set_category_name(elem.label()); - } - if (elem.has_display_name()) { - category->set_display_name(elem.display_name()); - } - } - entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / - 1000); - } - } - return result; -} - MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator); } // namespace api2 diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 7b455c0c4..27d2357b5 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], -) - mediapipe_proto_library( name = "classifications_proto", srcs = ["classifications.proto"], deps = [ - ":category_proto", "//mediapipe/framework/formats:classification_proto", ], ) diff --git 
a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto deleted file mode 100644 index 412e71428..000000000 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto2"; - -package mediapipe.tasks.components.containers.proto; - -option java_package = "com.google.mediapipe.tasks.components.containers.proto"; -option java_outer_classname = "CategoryProto"; - -// TODO: deprecate this message once migration is over. -// A single classification result. -message Category { - // The index of the category in the corresponding label map, usually packed in - // the TFLite Model Metadata [1]. - // - // [1]: https://www.tensorflow.org/lite/convert/metadata - optional int32 index = 1; - // The score for this category, e.g. (but not necessarily) a probability in - // [0,1]. - optional float score = 2; - // A human readable name of the category filled from the label map. - optional string display_name = 3; - // An ID for the category, not necessarily human-readable, e.g. a Google - // Knowledge Graph ID [1], filled from the label map. - // - // [1]: https://developers.google.com/knowledge-graph - optional string category_name = 4; -} diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index f098ed0e4..2b2306829 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -18,27 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; import "mediapipe/framework/formats/classification.proto"; -import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; -// TODO: deprecate this message once migration is over. -// List of predicted categories with an optional timestamp. -message ClassificationEntry { - // The array of predicted categories, usually sorted by descending scores, - // e.g., from high to low probability. - repeated Category categories = 1; - // The optional timestamp (in milliseconds) associated to the classifcation - // entry. This is useful for time series use cases, e.g., audio - // classification. - optional int64 timestamp_ms = 2; -} - // Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { - // TODO: deprecate this field once migration is over. - repeated ClassificationEntry entries = 1; // The classification results for this head. optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. 
This is useful @@ -48,6 +33,8 @@ message Classifications { // name. // TODO: Add github link to metadata_schema.fbs. optional string head_name = 3; + // Reserved fields. + reserved 1; } // Classifications for a given classifier model. diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 0fb62afaf..5a0472f5c 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; @@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; // Struct holding the different output streams produced by the graph. struct ClassificationPostprocessingOutputStreams { - Source classification_result; Source classifications; Source> timestamped_classifications; }; @@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph( // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Connects output. ClassificationPostprocessingOutputStreams output_streams{ - /*classification_result=*/result_aggregation - [Output(kClassificationResultTag)], /*classifications=*/ result_aggregation[Output(kClassificationsTag)], /*timestamped_classifications=*/ diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 48575ceb0..03ae91130 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -58,9 +58,6 @@ namespace processors { // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. 
absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index d4728e725..8eb6f3c3b 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsName[] = "tensors"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kClassificationResultName[] = "classification_result"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsName[] = "classifications"; constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; @@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { })pb")})); } -// TODO: remove these tests once migration is over. -class LegacyPostprocessingTest : public tflite_shims::testing::Test { - protected: - absl::StatusOr BuildGraph( - absl::string_view model_name, const proto::ClassifierOptions& options, - bool connect_timestamps = false) { - ASSIGN_OR_RETURN(auto model_resources, - CreateModelResourcesForModel(model_name)); - - Graph graph; - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors." - "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( - *model_resources, options, - &postprocessing - .GetOptions())); - graph[Input>(kTensorsTag)].SetName(kTensorsName) >> - postprocessing.In(kTensorsTag); - if (connect_timestamps) { - graph[Input>(kTimestampsTag)].SetName( - kTimestampsName) >> - postprocessing.In(kTimestampsTag); - } - postprocessing.Out(kClassificationResultTag) - .SetName(kClassificationResultName) >> - graph[Output(kClassificationResultTag)]; - - MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); - ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( - kClassificationResultName)); - MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); - return poller; - } - - template - void AddTensor( - const std::vector& tensor, const Tensor::ElementType& element_type, - const Tensor::QuantizationParameters& quantization_parameters = {}) { - tensors_->emplace_back(element_type, - Tensor::Shape{1, static_cast(tensor.size())}, - quantization_parameters); - auto view = tensors_->back().GetCpuWriteView(); - T* buffer = view.buffer(); - std::copy(tensor.begin(), tensor.end(), buffer); - } - - absl::Status Run( - std::optional> aggregation_timestamps = std::nullopt, - int timestamp = 0) { - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); - // Reset tensors for future calls. 
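  // (Context, inferred from this diff: the LegacyPostprocessingTest fixture
  // below polled the removed CLASSIFICATION_RESULT stream; the
  // PostprocessingTest cases kept above already cover the CLASSIFICATIONS and
  // TIMESTAMPED_CLASSIFICATIONS outputs.)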
- tensors_ = absl::make_unique>(); - if (aggregation_timestamps.has_value()) { - auto packet = absl::make_unique>(); - for (const auto& timestamp : *aggregation_timestamps) { - packet->emplace_back(Timestamp(timestamp)); - } - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); - } - return absl::OkStatus(); - } - - absl::StatusOr GetClassificationResult( - OutputStreamPoller& poller) { - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); - MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); - - Packet packet; - if (!poller.Next(&packet)) { - return absl::InternalError("Unable to get output packet"); - } - auto result = packet.Get(); - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); - return result; - } - - private: - CalculatorGraph calculator_graph_; - std::unique_ptr> tensors_ = - absl::make_unique>(); -}; - -TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - options.set_score_threshold(0.5); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 18; - tensor[2] = 16; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto(R"pb(classifications { - entries { - categories { index: 1 score: 0.8 } - categories { index: 2 score: 0.6 } - timestamp_ms: 0 - } - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.8 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. 
- EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.6899744811 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6456563062 - category_name: "great white shark" - } - categories { - index: 2 - score: 0.5986876601 - category_name: "goldfish" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor_0(kTwoHeadsNumClasses[0], 0); - tensor_0[1] = 0.2; - tensor_0[2] = 0.4; - tensor_0[3] = 0.6; - std::vector tensor_1(kTwoHeadsNumClasses[1], 0); - tensor_1[1] = 0.2; - tensor_1[2] = 0.4; - tensor_1[3] = 0.6; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kFloat32); - AddTensor(tensor_1, Tensor::ElementType::kFloat32); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Narration, monologue" - } - categories { - index: 2 - score: 0.4 - category_name: "Conversation" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "yamnet_classification" - } - classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Azara\'s Spinetail" - } - categories { - index: 2 - score: 0.4 - category_name: "House Sparrow" - } - timestamp_ms: 0 - } - head_index: 1 - head_name: "bird_classification" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, - /*connect_timestamps=*/true)); - // Build input tensors. - std::vector tensor_0(kMobileNetNumClasses, 0); - tensor_0[1] = 12; - tensor_0[2] = 14; - tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); - tensor_1[5] = 12; - tensor_1[6] = 14; - tensor_1[7] = 16; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - AddTensor(tensor_1, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run( - /*aggregation_timestamps=*/std::optional>({0, 1000}), - /*timestamp=*/1000)); - - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. 
- EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - entries { - categories { index: 7 score: 0.6 category_name: "stingray" } - categories { - index: 6 - score: 0.4 - category_name: "electric ray" - } - timestamp_ms: 1 - } - head_index: 0 - head_name: "probability" - })pb")); -} - } // namespace } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 36ff68a07..9a7dce1aa 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -// TODO: remove once Java API migration is over. -// Struct holding the different output streams produced by the text classifier. -struct TextClassifierOutputStreams { - Source classification_result; - Source classifications; -}; - } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -72,10 +64,6 @@ struct TextClassifierOutputStreams { // Outputs: // CLASSIFICATIONS - ClassificationResult @Optional // The classification results aggregated by classifier head. -// TODO: remove once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). // // Example: // node { @@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto output_streams, + auto classifications, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; - output_streams.classifications >> - graph[Output(kClassificationsTag)]; + classifications >> graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildTextClassifierTask( + absl::StatusOr> BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { @@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
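    // (Sketch of the api2 builder idiom used here, not new in this patch:
    // postprocessing[Output<ClassificationResult>(kClassificationsTag)] looks
    // up the node's typed output stream by tag, and returning that
    // Source<ClassificationResult> lets the caller wire it to a graph output.)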
- return TextClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], - /*classifications=*/postprocessing[Output( - kClassificationsTag)]}; + return postprocessing[Output(kClassificationsTag)]; } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8fa1a0d2a..2fc88bcb6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS"; // Struct holding the different output streams produced by the image classifier // subgraph. struct ImageClassifierOutputStreams { - Source classification_result; Source classifications; Source image; }; @@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams { // The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. -// TODO: remove this output once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example: // node { @@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
     return ImageClassifierOutputStreams{
-        /*classification_result=*/postprocessing[Output<ClassificationResult>(
-            kClassificationResultTag)],
         /*classifications=*/
         postprocessing[Output<ClassificationResult>(kClassificationsTag)],
         /*image=*/preprocessing[Output<Image>(kImageTag)]};
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
index 0e72878ab..023a1f286 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
@@ -48,7 +48,6 @@ android_library(
     deps = [
         "//mediapipe/framework:calculator_options_java_proto_lite",
         "//mediapipe/java/com/google/mediapipe/framework:android_framework",
-        "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
         "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index 289e3000d..72cee133f 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -97,7 +97,6 @@ android_library(
         "//mediapipe/framework:calculator_options_java_proto_lite",
         "//mediapipe/java/com/google/mediapipe/framework:android_framework",
         "//mediapipe/java/com/google/mediapipe/framework/image",
-        "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
         "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD
index d931c26c7..9d275e167 100644
--- a/mediapipe/tasks/python/components/containers/BUILD
+++ b/mediapipe/tasks/python/components/containers/BUILD
@@ -68,7 +68,7 @@ py_library(
     name = "category",
     srcs = ["category.py"],
     deps = [
-        "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
+        "//mediapipe/framework/formats:classification_py_pb2",
         "//mediapipe/tasks/python/core:optional_dependencies",
     ],
 )
diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py
index cfdb83740..9b5419883 100644
--- a/mediapipe/tasks/python/components/containers/category.py
+++ b/mediapipe/tasks/python/components/containers/category.py
@@ -16,10 +16,10 @@ import dataclasses
 from typing import Any, Optional

-from mediapipe.tasks.cc.components.containers.proto import category_pb2
+from mediapipe.framework.formats import classification_pb2
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls

-_CategoryProto = category_pb2.Category
+_ClassificationProto = classification_pb2.Classification

 @dataclasses.dataclass
@@ -45,23 +45,23 @@ class Category:
   category_name: Optional[str] = None

   @doc_controls.do_not_generate_docs
-  def to_pb2(self) -> _CategoryProto:
+  def to_pb2(self) -> _ClassificationProto:
     """Generates a Category protobuf object."""
-    return _CategoryProto(
+    return _ClassificationProto(
         index=self.index,
         score=self.score,
-        display_name=self.display_name,
-        category_name=self.category_name)
+        label=self.category_name,
+
display_name=self.display_name) @classmethod @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category': + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category': """Creates a `Category` object from the given protobuf object.""" return Category( index=pb2_obj.index, score=pb2_obj.score, display_name=pb2_obj.display_name, - category_name=pb2_obj.category_name) + category_name=pb2_obj.label) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index 6ffdabe51..000468041 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -49,11 +49,7 @@ class Classifications: """Generates a Classifications protobuf object.""" classification_list_proto = _ClassificationListProto() for category in self.categories: - classification_proto = _ClassificationProto( - index=category.index, - score=category.score, - label=category.category_name, - display_name=category.display_name) + classification_proto = category.to_pb2() classification_list_proto.classification.append(classification_proto) return _ClassificationsProto( classification_list=classification_list_proto, @@ -65,14 +61,9 @@ class Classifications: def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': """Creates a `Classifications` object from the given protobuf object.""" categories = [] - for entry in pb2_obj.classification_list.classification: + for classification in pb2_obj.classification_list.classification: categories.append( - category_module.Category( - index=entry.index, - score=entry.score, - display_name=entry.display_name, - category_name=entry.label)) - + category_module.Category.create_from_pb2(classification)) return Classifications( categories=categories, head_index=pb2_obj.head_index, From 7f0134eecbe75a94bcda7cf113e1ae8aa47cd916 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 12:13:38 -0800 Subject: [PATCH 028/346] Internal change PiperOrigin-RevId: 490041386 --- mediapipe/tasks/python/core/BUILD | 1 + mediapipe/tasks/python/text/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 76e2f4f4a..fc0018ab1 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -31,6 +31,7 @@ py_library( py_library( name = "base_options", srcs = ["base_options.py"], + visibility = ["//mediapipe/tasks:users"], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index bb42da912..10b4b8a6e 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -23,6 +23,7 @@ py_library( srcs = [ "text_classifier.py", ], + visibility = ["//mediapipe/tasks:users"], deps = [ "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", From 652423a23d9a69d5c3dabe61926a55bd77d6d610 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 13:04:53 -0800 Subject: [PATCH 029/346] Internal change PiperOrigin-RevId: 490053179 --- mediapipe/calculators/tensor/image_to_tensor_utils.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git 
a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc
index d27c595b5..3f91f3dc2 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc
@@ -253,11 +253,15 @@ int GetNumOutputChannels(const mediapipe::Image& image) {
 }
 #endif  // MEDIAPIPE_METAL_ENABLED
 #endif  // !MEDIAPIPE_DISABLE_GPU
-  // The output tensor channel is 1 for the input image with 1 channel; And the
-  // output tensor channels is 3 for the input image with 3 or 4 channels.
   // TODO: Add a unittest here to test the behavior on GPU, i.e.
   // failure.
-  return image.channels() == 1 ? 1 : 3;
+  // Only output channel == 1 when running on CPU and the input image channel
+  // is 1. Ideally, we want to also support GPU for output channel == 1. But
+  // setting this on the safer side to prevent unintentional failure.
+  if (!image.UsesGpu() && image.channels() == 1) {
+    return 1;
+  }
+  return 3;
 }

 absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
From adddf2c2abe953b0280507b6168a41bcbb5a08f3 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 21 Nov 2022 14:37:42 -0800
Subject: [PATCH 030/346] Extracted common test helper functions out from the
 unittest into a sharable library. Also migrated away from OpenCVX.

PiperOrigin-RevId: 490074410
---
 mediapipe/calculators/tensor/BUILD            |   2 +
 .../tensor/image_to_tensor_calculator_test.cc | 169 ++++++------------
 mediapipe/util/BUILD                          |  18 ++
 mediapipe/util/image_test_utils.cc            |  57 ++++++
 mediapipe/util/image_test_utils.h             |  32 ++++
 5 files changed, 166 insertions(+), 112 deletions(-)
 create mode 100644 mediapipe/util/image_test_utils.cc
 create mode 100644 mediapipe/util/image_test_utils.h

diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index 2a573fc44..645189a07 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -30,6 +30,7 @@ exports_files(
     glob(["testdata/image_to_tensor/*"]),
     visibility = [
         "//mediapipe/calculators/image:__subpackages__",
+        "//mediapipe/util:__subpackages__",
     ],
 )

@@ -1133,6 +1134,7 @@ cc_test(
         "//mediapipe/framework/port:opencv_imgcodecs",
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/util:image_test_utils",
         "@com_google_absl//absl/flags:flag",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
index 7ea60d98e..ceb1fc502 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc
@@ -36,29 +36,17 @@
 #include "mediapipe/framework/port/opencv_imgproc_inc.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/util/image_test_utils.h"

 namespace mediapipe {
 namespace {

-cv::Mat GetRgb(absl::string_view path) {
-  cv::Mat bgr = cv::imread(file::JoinPath("./", path));
-  cv::Mat rgb;
-  cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB);
-  return rgb;
-}
+constexpr char kTestDataDir[] =
+    "/mediapipe/calculators/tensor/testdata/"
+    "image_to_tensor/";

-cv::Mat GetRgba(absl::string_view path) {
-  cv::Mat bgr = cv::imread(file::JoinPath("./", path));
-  cv::Mat rgb;
-  cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA);
-  return rgb;
-}
-
-cv::Mat GetGray(absl::string_view path) {
-  cv::Mat bgr = cv::imread(file::JoinPath("./", path));
-
cv::Mat gray; - cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY); - return gray; +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDir, filename); } // Image to tensor test template. @@ -259,15 +247,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, - /*border mode*/ {}, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + /*border mode*/ {}, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { @@ -277,11 +262,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -295,11 +277,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -314,11 +293,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -332,16 +309,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb( - "/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*float_ranges=*/{{-1.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, - BorderMode::kReplicate, roi); + 
RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation.png")), + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { @@ -351,11 +324,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation_border_zero.png")), /*float_ranges=*/{{-1.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, @@ -369,10 +339,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, @@ -386,15 +354,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, - BorderMode::kZero, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kZero, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { @@ -404,15 +369,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { @@ -422,11 +384,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - 
"tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -440,11 +399,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -458,11 +414,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -477,11 +430,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -496,11 +447,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -514,10 +463,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -531,10 +478,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { roi.set_width(1.0f); 
roi.set_height(1.0f);
   roi.set_rotation(0);
-  RunTest(GetRgba("/mediapipe/calculators/"
-                  "tensor/testdata/image_to_tensor/input.jpg"),
-          GetRgb("/mediapipe/calculators/"
-                 "tensor/testdata/image_to_tensor/noop_except_range.png"),
+  RunTest(GetRgba(GetFilePath("input.jpg")),
+          GetRgb(GetFilePath("noop_except_range.png")),
           /*float_ranges=*/{{0.0f, 1.0f}},
           /*int_ranges=*/{{0, 255}, {-128, 127}},
           /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD
index 15835aea5..55c1df59f 100644
--- a/mediapipe/util/BUILD
+++ b/mediapipe/util/BUILD
@@ -368,3 +368,21 @@ cc_test(
         "//mediapipe/framework/port:gtest_main",
     ],
 )
+
+cc_library(
+    name = "image_test_utils",
+    testonly = 1,
+    srcs = ["image_test_utils.cc"],
+    hdrs = ["image_test_utils.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:image_frame",
+        "//mediapipe/framework/formats:image_frame_opencv",
+        "//mediapipe/framework/port:opencv_core",
+        "//mediapipe/framework/port:opencv_imgcodecs",
+        "//mediapipe/framework/port:opencv_imgproc",
+    ],
+)
diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc
new file mode 100644
index 000000000..815666985
--- /dev/null
+++ b/mediapipe/util/image_test_utils.cc
@@ -0,0 +1,57 @@
+#include "mediapipe/util/image_test_utils.h"
+
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/timestamp.h"
+
+namespace mediapipe {
+
+cv::Mat GetRgb(const std::string& path) {
+  cv::Mat bgr = cv::imread(path);
+  cv::Mat rgb;
+  cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB);
+  return rgb;
+}
+
+cv::Mat GetRgba(const std::string& path) {
+  cv::Mat bgr = cv::imread(path);
+  cv::Mat rgb;
+  cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA);
+  return rgb;
+}
+
+cv::Mat GetGray(const std::string& path) {
+  cv::Mat bgr = cv::imread(path);
+  cv::Mat gray;
+  cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
+  return gray;
+}
+
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
+  if (image_channels == 4) {
+    return ImageFormat::SRGBA;
+  } else if (image_channels == 3) {
+    return ImageFormat::SRGB;
+  } else if (image_channels == 1) {
+    return ImageFormat::GRAY8;
+  }
+  LOG(FATAL) << "Unsupported input image channels: " << image_channels;
+}
+
+Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
+  ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
+                         input.rows, input.step, input.data, [](uint8*) {});
+  return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0));
+}
+
+Packet MakeImagePacket(cv::Mat input, int timestamp) {
+  mediapipe::Image input_image(std::make_shared<ImageFrame>(
+      GetImageFormat(input.channels()), input.cols, input.rows, input.step,
+      input.data, [](uint8*) {}));
+  return MakePacket<Image>(std::move(input_image)).At(Timestamp(0));
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/util/image_test_utils.h b/mediapipe/util/image_test_utils.h
new file mode 100644
index 000000000..6df9644d2
--- /dev/null
+++ b/mediapipe/util/image_test_utils.h
@@ -0,0 +1,32 @@
+#ifndef MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
+#define MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
+
+#include <string>
+
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+
+namespace mediapipe {
+
+// Reads the image file into cv::Mat with RGB channels.
+cv::Mat GetRgb(const std::string& path);
+
+// Reads the image file into cv::Mat with RGBA channels.
+cv::Mat GetRgba(const std::string& path);
+
+// Reads the image file into cv::Mat with Gray channel.
+cv::Mat GetGray(const std::string& path);
+
+// Converts the image channels into corresponding ImageFormat.
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels);
+
+// Converts the cv::Mat into ImageFrame packet.
+Packet MakeImageFramePacket(cv::Mat input, int timestamp = 0);
+
+// Converts the cv::Mat into Image packet.
+Packet MakeImagePacket(cv::Mat input, int timestamp = 0);
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
From d43d0ff615030abb9241c28e6de6e345a8dba7eb Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 21 Nov 2022 15:45:29 -0800
Subject: [PATCH 031/346] Internal change

PiperOrigin-RevId: 490089940
---
 .../image_to_tensor_converter_opencv.cc       | 43 +++++++++++++------
 1 file changed, 30 insertions(+), 13 deletions(-)

diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc
index 76e46f99d..95e38f89c 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc
@@ -76,31 +76,49 @@ class OpenCvProcessor : public ImageToTensorConverter {
       return InvalidArgumentError(absl::StrCat(
          "Unsupported format: ", static_cast<uint32_t>(input.image_format())));
     }
-    // TODO: Remove the check once tensor_buffer_offset > 0 is
-    // supported.
-    RET_CHECK_EQ(tensor_buffer_offset, 0)
-        << "The non-zero tensor_buffer_offset input is not supported yet.";
+
+    RET_CHECK_GE(tensor_buffer_offset, 0)
+        << "The input tensor_buffer_offset needs to be non-negative.";
     const auto& output_shape = output_tensor.shape();
     MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
     const int output_height = output_shape.dims[1];
     const int output_width = output_shape.dims[2];
     const int output_channels = output_shape.dims[3];
+    const int num_elements_per_img =
+        output_height * output_width * output_channels;
     auto buffer_view = output_tensor.GetCpuWriteView();
     cv::Mat dst;
     const int dst_data_type = output_channels == 1 ?
mat_gray_type_ : mat_type_;
     switch (tensor_type_) {
       case Tensor::ElementType::kInt8:
-        dst = cv::Mat(output_height, output_width, dst_data_type,
-                      buffer_view.buffer<int8>());
+        RET_CHECK_GE(output_shape.num_elements(),
+                     tensor_buffer_offset / sizeof(int8) + num_elements_per_img)
+            << "The buffer offset + the input image size is larger than the "
+               "allocated tensor buffer.";
+        dst = cv::Mat(
+            output_height, output_width, dst_data_type,
+            buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8));
         break;
       case Tensor::ElementType::kFloat32:
-        dst = cv::Mat(output_height, output_width, dst_data_type,
-                      buffer_view.buffer<float>());
+        RET_CHECK_GE(
+            output_shape.num_elements(),
+            tensor_buffer_offset / sizeof(float) + num_elements_per_img)
+            << "The buffer offset + the input image size is larger than the "
+               "allocated tensor buffer.";
+        dst = cv::Mat(
+            output_height, output_width, dst_data_type,
+            buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float));
         break;
       case Tensor::ElementType::kUInt8:
-        dst = cv::Mat(output_height, output_width, dst_data_type,
-                      buffer_view.buffer<uint8>());
+        RET_CHECK_GE(
+            output_shape.num_elements(),
+            tensor_buffer_offset / sizeof(uint8) + num_elements_per_img)
+            << "The buffer offset + the input image size is larger than the "
+               "allocated tensor buffer.";
+        dst = cv::Mat(
+            output_height, output_width, dst_data_type,
+            buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8));
         break;
       default:
         return InvalidArgumentError(
@@ -153,9 +171,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
   absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
     RET_CHECK_EQ(output_shape.dims.size(), 4)
         << "Wrong output dims size: " << output_shape.dims.size();
-    RET_CHECK_EQ(output_shape.dims[0], 1)
-        << "Handling batch dimension not equal to 1 is not implemented in this "
-           "converter.";
+    RET_CHECK_GE(output_shape.dims[0], 1)
+        << "The batch dimension needs to be equal or larger than 1.";
     RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
         << "Wrong output channel: " << output_shape.dims[3];
     return absl::OkStatus();
From 7c9fc9a6428b1c40738b5dce80abbacd627c4bdf Mon Sep 17 00:00:00 2001
From: Mark McDonald
Date: Mon, 21 Nov 2022 21:45:58 -0800
Subject: [PATCH 032/346] Remove `mp.solutions` from doc generation.

These need to be excluded from the current package, so do it
automatically.

PiperOrigin-RevId: 490146934
---
 docs/build_py_api_docs.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py
index fa1e4314f..fe706acd3 100644
--- a/docs/build_py_api_docs.py
+++ b/docs/build_py_api_docs.py
@@ -30,7 +30,7 @@ from tensorflow_docs.api_generator import public_api

 try:
   # mediapipe has not been set up to work with bazel yet, so catch & report.
-  import mediapipe  # pytype: disable=import-error
+  import mediapipe as mp  # pytype: disable=import-error
 except ImportError as e:
   raise ImportError('Please `pip install mediapipe`.') from e

@@ -58,11 +58,13 @@ _SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',

 def gen_api_docs():
   """Generates API docs for the mediapipe package."""
+  if hasattr(mp, 'solutions'):
+    del mp.solutions

   doc_generator = generate_lib.DocGenerator(
       root_title=PROJECT_FULL_NAME,
-      py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
-      base_dir=os.path.dirname(mediapipe.__file__),
+      py_modules=[(PROJECT_SHORT_NAME, mp)],
+      base_dir=os.path.dirname(mp.__file__),
       code_url_prefix=_URL_PREFIX.value,
       search_hints=_SEARCH_HINTS.value,
       site_path=_SITE_PATH.value,
From 54a684717fa39cd39315f8f6cb60b6c5a7fa76aa Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 21 Nov 2022 23:22:49 -0800
Subject: [PATCH 033/346] Internal change

PiperOrigin-RevId: 490159674
---
 mediapipe/gpu/attachments.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/gpu/attachments.h b/mediapipe/gpu/attachments.h
index ca9f074c4..3a73e4676 100644
--- a/mediapipe/gpu/attachments.h
+++ b/mediapipe/gpu/attachments.h
@@ -31,8 +31,8 @@ class AttachmentBase {};

 template <class Context, class T>
 class Attachment : public AttachmentBase<Context> {
  public:
-  using FactoryT = std::function<AttachmentPtr<T>(Context&)>;
-  Attachment(FactoryT factory) : factory_(factory) {}
+  using FactoryT = AttachmentPtr<T> (*)(Context&);
+  explicit constexpr Attachment(FactoryT factory) : factory_(factory) {}

   Attachment(const Attachment&) = delete;
   Attachment(Attachment&&) = delete;
From a8b776102240ecb73f1a7aeb8ace9db42eb05f96 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 21 Nov 2022 23:27:55 -0800
Subject: [PATCH 034/346] Define a kUtilityFramebuffer context attachment

A framebuffer object is often needed to render to a texture or read data
from it. Currently we create one in each GlCalculatorHelper, but that is
redundant (we only need one per context, and multiple calculators can share
the same context). Other times, the code that needs to use this doesn't own
a helper.

For both reasons, this should be attached to the context. We could just make
this a member of GlContext since it's so common. However, I figured we might
as well use the attachment system.

PiperOrigin-RevId: 490160214
---
 mediapipe/gpu/gl_context.cc | 12 ++++++++++++
 mediapipe/gpu/gl_context.h  |  6 ++++++
 2 files changed, 18 insertions(+)

diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc
index 53e3ff8b7..99b995dda 100644
--- a/mediapipe/gpu/gl_context.cc
+++ b/mediapipe/gpu/gl_context.cc
@@ -1054,4 +1054,16 @@ void GlContext::SetStandardTextureParams(GLenum target, GLint internal_format) {
   glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
 }

+const GlContext::Attachment<GLuint> kUtilityFramebuffer(
+    [](GlContext&) -> GlContext::Attachment<GLuint>::Ptr {
+      GLuint framebuffer;
+      glGenFramebuffers(1, &framebuffer);
+      if (!framebuffer) return nullptr;
+      return {new GLuint(framebuffer), [](void* ptr) {
+                GLuint* fb = static_cast<GLuint*>(ptr);
+                glDeleteFramebuffers(1, fb);
+                delete fb;
+              }};
+    });
+
 }  // namespace mediapipe
diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h
index 7f5168d8b..4f2390404 100644
--- a/mediapipe/gpu/gl_context.h
+++ b/mediapipe/gpu/gl_context.h
@@ -474,6 +474,12 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   bool destructing_ = false;
 };

+// A framebuffer that the framework can use to attach textures for rendering
+// etc.
+// This could just be a member of GlContext, but it serves as a basic example
+// of an attachment.
+ABSL_CONST_INIT extern const GlContext::Attachment<GLuint> kUtilityFramebuffer;
+
 // For backward compatibility. TODO: migrate remaining callers.
 ABSL_DEPRECATED(
     "Prefer passing an explicit GlVersion argument (use "
From bacbac8d926d769bf51f770914d603b942094ebb Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 21 Nov 2022 23:57:33 -0800
Subject: [PATCH 035/346] Use kUtilityFramebuffer in ReadTexture

This avoids creating a temporary framebuffer each time.

PiperOrigin-RevId: 490163892
---
 mediapipe/gpu/gl_texture_buffer.cc | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc
index 7f77cd4b3..3d2642552 100644
--- a/mediapipe/gpu/gl_texture_buffer.cc
+++ b/mediapipe/gpu/gl_texture_buffer.cc
@@ -15,6 +15,7 @@
 #include "mediapipe/gpu/gl_texture_buffer.h"

 #include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/gpu/gl_context.h"
 #include "mediapipe/gpu/gl_texture_view.h"
 #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"

@@ -333,8 +334,8 @@ void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) {
 #endif  // __ANDROID__
 }

-static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
-                        void* output, size_t size) {
+static void ReadTexture(GlContext& ctx, const GlTextureView& view,
+                        GpuBufferFormat format, void* output, size_t size) {
   // TODO: check buffer size? We could use glReadnPixels where available
   // (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
   // won't overflow the buffer with glReadPixels, we'd also need to check or
@@ -347,10 +348,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
   GLint previous_fbo;
   glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo);

-  // We use a temp fbo to avoid depending on the app having an existing one.
-  // TODO: keep a utility fbo around in the context?
-  GLuint fbo = 0;
-  glGenFramebuffers(1, &fbo);
+  GLuint fbo = kUtilityFramebuffer.Get(ctx);
   glBindFramebuffer(GL_FRAMEBUFFER, fbo);
   glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
                          view.name(), 0);
@@ -360,7 +358,6 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
                          0);
   // TODO: just set the binding to 0 to avoid the get call?
   glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo);
-  glDeleteFramebuffers(1, &fbo);
 }

 static std::shared_ptr<GpuBufferStorageImageFrame> ConvertToImageFrame(
@@ -370,9 +367,10 @@ static std::shared_ptr<GpuBufferStorageImageFrame> ConvertToImageFrame(
   auto output =
       absl::make_unique<ImageFrame>(image_format, buf->width(), buf->height(),
                                     ImageFrame::kGlDefaultAlignmentBoundary);
-  buf->GetProducerContext()->Run([buf, &output] {
+  auto ctx = buf->GetProducerContext();
+  ctx->Run([buf, &output, &ctx] {
     auto view = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0);
-    ReadTexture(view, buf->format(), output->MutablePixelData(),
+    ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(),
                 output->PixelDataSize());
   });
   return std::make_shared<GpuBufferStorageImageFrame>(std::move(output));
From d648926155d19cb6665895661624ec19cc7d33c6 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 00:35:27 -0800
Subject: [PATCH 036/346] Just reset the fb binding to 0 in ReadTexture

This saves a get operation. We already have precedent in lots of other
MediaPipe code where we just reset bindings to 0.
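For context, the two binding patterns differ only in how the previous framebuffer state is handled. A minimal sketch of the before/after behavior, assuming a current GL context and a valid fbo (illustrative only, not part of the patch):

    // Query-and-restore: costs a synchronous glGetIntegerv round-trip.
    GLint previous_fbo;
    glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo);
    glBindFramebuffer(GL_FRAMEBUFFER, fbo);
    // ... glReadPixels into the caller's buffer ...
    glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo);

    // Reset-to-zero: assumes callers do not rely on a leftover binding.
    glBindFramebuffer(GL_FRAMEBUFFER, fbo);
    // ... glReadPixels into the caller's buffer ...
    glBindFramebuffer(GL_FRAMEBUFFER, 0);

Both variants produce the same read result; the second simply drops the get call, matching the convention the message cites.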
PiperOrigin-RevId: 490170691 --- mediapipe/gpu/gl_texture_buffer.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 3d2642552..d530d5d12 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -345,9 +345,6 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view, GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format, view.plane(), view.gl_context()->GetGlVersion()); - GLint previous_fbo; - glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo); - GLuint fbo = kUtilityFramebuffer.Get(ctx); glBindFramebuffer(GL_FRAMEBUFFER, fbo); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), @@ -356,8 +353,7 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view, output); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0, 0); - // TODO: just set the binding to 0 to avoid the get call? - glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo); + glBindFramebuffer(GL_FRAMEBUFFER, 0); } static std::shared_ptr ConvertToImageFrame( From 872d1afda7f8a465db59dfcf9ab56e6d60832646 Mon Sep 17 00:00:00 2001 From: vrabaud Date: Tue, 22 Nov 2022 03:10:35 -0800 Subject: [PATCH 037/346] Internal change PiperOrigin-RevId: 490196129 --- mediapipe/framework/port/BUILD | 11 ++++++++++ mediapipe/framework/port/opencv_videoio_inc.h | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 mediapipe/framework/port/opencv_videoio_inc.h diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 87944d80f..e499ca3a6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -311,6 +311,17 @@ cc_library( ], ) +cc_library( + name = "opencv_videoio", + hdrs = ["opencv_videoio_inc.h"], + visibility = ["//visibility:public"], + deps = [ + ":opencv_core", + "//mediapipe/framework:port", + "//third_party:opencv", + ], +) + cc_library( name = "parse_text_proto", hdrs = [ diff --git a/mediapipe/framework/port/opencv_videoio_inc.h b/mediapipe/framework/port/opencv_videoio_inc.h new file mode 100644 index 000000000..63029b69f --- /dev/null +++ b/mediapipe/framework/port/opencv_videoio_inc.h @@ -0,0 +1,21 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#ifndef MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_
+#define MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_
+
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#include "third_party/OpenCV/videoio.hpp"
+
+#endif  // MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_
From 515d00fc22100bfb948aecfa39408a0b599a0c89 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 15:16:52 -0800
Subject: [PATCH 038/346] Internal change

PiperOrigin-RevId: 490349260
---
 mediapipe/framework/formats/BUILD | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index e13bb2704..4276ffc3a 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -312,9 +312,7 @@ mediapipe_register_type(
 mediapipe_proto_library(
     name = "landmark_proto",
     srcs = ["landmark.proto"],
-    visibility = [
-        "//visibility:public",
-    ],
+    visibility = ["//visibility:public"],
 )

 mediapipe_register_type(
From 7ce4aa6592c30c2ac5d0c075304e50ae7d01b38f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 22 Nov 2022 16:38:51 -0800
Subject: [PATCH 039/346] Internal change

PiperOrigin-RevId: 490366250
---
 mediapipe/util/sequence/media_sequence_test.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc
index 40a474599..42b0e3889 100644
--- a/mediapipe/util/sequence/media_sequence_test.cc
+++ b/mediapipe/util/sequence/media_sequence_test.cc
@@ -802,7 +802,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImages) {
   tensorflow::SequenceExample sequence;
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {}));
   std::string encoded_image(bytes.begin(), bytes.end());
   AddImageEncoded(encoded_image, &sequence);
   AddImageEncoded(encoded_image, &sequence);
@@ -843,7 +843,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFlowEncoded) {
   tensorflow::SequenceExample sequence;
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {}));
   std::string encoded_flow(bytes.begin(), bytes.end());

   AddForwardFlowEncoded(encoded_flow, &sequence);
From efa9e737f80e245aec4c6ef9483fc92547e6d1d9 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 17:22:18 -0800
Subject: [PATCH 040/346] Use current context if available in ConvertToImageFrame

If we're already running in a GlContext, there's no need to go back to the
producer context, which may be different.
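The selection logic this patch introduces can be summarized in a short sketch (illustrative only; names as in the diff below):

    // Prefer the GL context already current on this thread, if any;
    // otherwise fall back to the context that produced the buffer.
    auto ctx = GlContext::GetCurrent();  // may return null
    if (!ctx) ctx = buf->GetProducerContext();
    ctx->Run([&] { /* read the texture back on the chosen context */ });

When the conversion is requested from code already running inside a GlContext, this avoids hopping to, and synchronizing with, a different context.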
PiperOrigin-RevId: 490373829
---
 mediapipe/gpu/gl_texture_buffer.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc
index d530d5d12..69b9889c7 100644
--- a/mediapipe/gpu/gl_texture_buffer.cc
+++ b/mediapipe/gpu/gl_texture_buffer.cc
@@ -363,7 +363,8 @@ static std::shared_ptr<GpuBufferStorageImageFrame> ConvertToImageFrame(
   auto output =
       absl::make_unique<ImageFrame>(image_format, buf->width(), buf->height(),
                                     ImageFrame::kGlDefaultAlignmentBoundary);
-  auto ctx = buf->GetProducerContext();
+  auto ctx = GlContext::GetCurrent();
+  if (!ctx) ctx = buf->GetProducerContext();
   ctx->Run([buf, &output, &ctx] {
     auto view = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0);
     ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(),
@@ -392,7 +393,9 @@ static std::shared_ptr<GpuBufferStorageCvPixelBuffer> ConvertToCvPixelBuffer(
     std::shared_ptr<GlTextureBuffer> buf) {
   auto output = absl::make_unique<GpuBufferStorageCvPixelBuffer>(
       buf->width(), buf->height(), buf->format());
-  buf->GetProducerContext()->Run([buf, &output] {
+  auto ctx = GlContext::GetCurrent();
+  if (!ctx) ctx = buf->GetProducerContext();
+  ctx->Run([buf, &output] {
     TempGlFramebuffer framebuffer;
     auto src = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0);
     auto dst =
From fac97554dfb80e8c14ecbfb2cbe12e0ad26ce0b4 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 22 Nov 2022 17:23:48 -0800
Subject: [PATCH 041/346] Small TS audio API improvement

PiperOrigin-RevId: 490374083
---
 .../audio_classifier/audio_classifier.ts      |  14 +-
 .../audio/audio_embedder/audio_embedder.ts    |  14 +-
 mediapipe/web/graph_runner/graph_runner.ts    | 129 ++++++++++++++----
 3 files changed, 105 insertions(+), 52 deletions(-)

diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
index 0c54a4718..20c745383 100644
--- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
+++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
@@ -35,11 +35,7 @@ export * from './audio_classifier_result';
 const MEDIAPIPE_GRAPH =
     'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';

-// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
-// cannot be changed
-// TODO: Change this to `audio_in` to match the name in the CC
-// implementation
-const AUDIO_STREAM = 'input_audio';
+const AUDIO_STREAM = 'audio_in';
 const SAMPLE_RATE_STREAM = 'sample_rate';
 const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';

@@ -154,14 +150,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioClassifierResult[] {
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStream(audioData, timestampMs);
+    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);

     this.classificationResults = [];
     this.finishProcessing();
diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
index 51cb819de..46a7b6729 100644
--- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
+++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
@@ -35,11 +35,7 @@ export * from './audio_embedder_result';
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern

-// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot
-// be changed
-// TODO: Change this to `audio_in` to match the name in the CC
-// implementation
-const AUDIO_STREAM = 'input_audio';
+const AUDIO_STREAM = 'audio_in';
 const SAMPLE_RATE_STREAM = 'sample_rate';
 const EMBEDDINGS_STREAM = 'embeddings_out';
 const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';
@@ -151,14 +147,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioEmbedderResult[] {
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStream(audioData, timestampMs);
+    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);

     this.embeddingResults = [];
     this.finishProcessing();
diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts
index 7de5aa33b..c4654794c 100644
--- a/mediapipe/web/graph_runner/graph_runner.ts
+++ b/mediapipe/web/graph_runner/graph_runner.ts
@@ -15,9 +15,6 @@ export declare interface FileLocator {
   locateFile: (filename: string) => string;
 }

-/** Listener to be passed in by user for handling output audio data. */
-export type AudioOutputListener = (output: Float32Array) => void;
-
 /**
  * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
  * doesn't break our JS/C++ bridge.
@@ -32,19 +29,14 @@ export declare interface WasmModule {
   _bindTextureToCanvas: () => boolean;
   _changeBinaryGraph: (size: number, dataPtr: number) => void;
   _changeTextGraph: (size: number, dataPtr: number) => void;
-  _configureAudio:
-      (channels: number, samples: number, sampleRate: number) => void;
   _free: (ptr: number) => void;
   _malloc: (size: number) => number;
-  _processAudio: (dataPtr: number, timestamp: number) => void;
   _processFrame: (width: number, height: number, timestamp: number) => void;
   _setAutoRenderToScreen: (enabled: boolean) => void;
   _waitUntilIdle: () => void;

   // Exposed so that clients of this lib can access this field
   dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
-  // Wasm module will call us back at this function when given audio data.
-  onAudioOutput?: AudioOutputListener;

   // Wasm Module multistream entrypoints. Require
   // gl_graph_runner_internal_multi_input as a build dependency.
@@ -100,11 +92,14 @@ export declare interface WasmModule {
   _attachProtoVectorListener:
       (streamNamePtr: number, makeDeepCopy?: boolean) => void;

-  // Requires dependency ":gl_graph_runner_audio_out", and will register an
-  // audio output listening function which can be tapped into dynamically during
-  // graph running via onAudioOutput. This call must be made before graph is
-  // initialized, but after wasmModule is instantiated.
-  _attachAudioOutputListener: () => void;
+  // Require dependency ":gl_graph_runner_audio_out"
+  _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
+
+  // Require dependency ":gl_graph_runner_audio"
+  _addAudioToInputStream: (dataPtr: number, numChannels: number,
+      numSamples: number, streamNamePtr: number, timestamp: number) => void;
+  _configureAudio: (streamNamePtr: number, headerNamePtr: number,
+      numChannels: number, numSamples: number, sampleRate: number) => void;

   // TODO: Refactor to just use a few numbers (perhaps refactor away
   // from gl_graph_runner_internal.cc entirely to use something a little more
@@ -235,19 +230,38 @@ export class GraphRunner {
   }

   /**
-   * Configures the current graph to handle audio in a certain way. Must be
-   * called before the graph is set/started in order to use processAudio.
+   * Configures the current graph to handle audio processing in a certain way
+   * for all its audio input streams. Additionally can configure audio headers
+   * (both input side packets as well as input stream headers), but these
+   * configurations only take effect if called before the graph is set/started.
    * @param numChannels The number of channels of audio input. Only 1
    *     is supported for now.
    * @param numSamples The number of samples that are taken in each
    *     audio capture.
    * @param sampleRate The rate, in Hz, of the sampling.
+   * @param streamName The optional name of the input stream to additionally
+   *     configure with audio information. This configuration only occurs before
+   *     the graph is set/started. If unset, a default stream name will be used.
+   * @param headerName The optional name of the header input side packet to
+   *     additionally configure with audio information. This configuration only
+   *     occurs before the graph is set/started. If unset, a default header name
+   *     will be used.
    */
-  configureAudio(numChannels: number, numSamples: number, sampleRate: number) {
-    this.wasmModule._configureAudio(numChannels, numSamples, sampleRate);
-    if (this.wasmModule._attachAudioOutputListener) {
-      this.wasmModule._attachAudioOutputListener();
+  configureAudio(numChannels: number, numSamples: number, sampleRate: number,
+      streamName?: string, headerName?: string) {
+    if (!this.wasmModule._configureAudio) {
+      console.warn(
+          'Attempting to use configureAudio without support for input audio. ' +
+          'Is build dep ":gl_graph_runner_audio" missing?');
     }
+    streamName = streamName || 'input_audio';
+    this.wrapStringPtr(streamName, (streamNamePtr: number) => {
+      headerName = headerName || 'audio_header';
+      this.wrapStringPtr(headerName, (headerNamePtr: number) => {
+        this.wasmModule._configureAudio(streamNamePtr, headerNamePtr,
+            numChannels, numSamples, sampleRate);
+      });
+    });
   }

   /**
@@ -437,9 +451,36 @@ export class GraphRunner {
    * processed.
    * @param audioData An array of raw audio capture data, like
    *     from a call to getChannelData on an AudioBuffer.
+
+ * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. * @param timestamp The timestamp of the current frame, in ms. */ - addAudioToStream(audioData: Float32Array, timestamp: number) { + addAudioToStream( + audioData: Float32Array, streamName: string, timestamp: number) { + // numChannels and numSamples being 0 will cause defaults to be used, + // which will reflect values from last call to configureAudio. + this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp); + } + + /** + * Takes the raw data from a JS audio capture array, and sends it to C++ to be + * processed, shaping the audioData array into an audio matrix according to + * the numChannels and numSamples parameters. + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param numChannels The number of audio channels this data represents. If 0 + * is passed, then the value will be taken from the last call to + * configureAudio. + * @param numSamples The number of audio samples captured in this data packet. + * If 0 is passed, then the value will be taken from the last call to + * configureAudio. + * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. + * @param timestamp The timestamp of the current frame, in ms. + */ + addAudioToStreamWithShape( + audioData: Float32Array, numChannels: number, numSamples: number, + streamName: string, timestamp: number) { // 4 bytes for each F32 const size = audioData.length * 4; if (this.audioSize !== size) { @@ -450,7 +491,11 @@ export class GraphRunner { this.audioSize = size; } this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); - this.wasmModule._processAudio(this.audioPtr!, timestamp); + + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addAudioToInputStream( + this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp); + }); } /** @@ -943,17 +988,45 @@ export class GraphRunner { } /** - * Sets a listener to be called back with audio output packet data, as a - * Float32Array, when graph has finished processing it. - * @param audioOutputListener The caller's listener function. + * Attaches an audio packet listener to the specified output_stream, to be + * given a Float32Array as output. + * @param outputStreamName The name of the graph output stream to grab audio + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. If the + * audio data needs to be able to outlive the call, you may set the + * optional makeDeepCopy parameter to true, or can manually deep-copy the + * data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). 
*/ - setOnAudioOutput(audioOutputListener: AudioOutputListener) { - this.wasmModule.onAudioOutput = audioOutputListener; - if (!this.wasmModule._attachAudioOutputListener) { + attachAudioListener(outputStreamName: string, + callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void { + if (!this.wasmModule._attachAudioListener) { console.warn( - 'Attempting to use AudioOutputListener without support for ' + + 'Attempting to use attachAudioListener without support for ' + 'output audio. Is build dep ":gl_graph_runner_audio_out" missing?'); } + + // Set up our TS listener to receive any packets for this stream, and + // additionally reformat our Uint8Array into a Float32Array for the user. + this.setListener(outputStreamName, (data: Uint8Array) => { + const floatArray = new Float32Array(data.buffer); // Should be very fast + callbackFcn(floatArray); + }); + + // Tell our graph to listen for string packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachAudioListener( + outputStreamNamePtr, makeDeepCopy || false); + }); } /** From 8ba9d87e667f0c6e67026f96aa58ee1a980b0ce1 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 17:25:55 -0800 Subject: [PATCH 042/346] Update ImageFrameToGpuBufferCalculator to use api2 and GpuBuffer conversions PiperOrigin-RevId: 490374387 --- mediapipe/gpu/BUILD | 2 + .../image_frame_to_gpu_buffer_calculator.cc | 62 ++++++++----------- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 10a8d7fff..f97eed678 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -901,6 +901,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gl_calculator_helper", + ":gpu_buffer_storage_image_frame", + "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:status", diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 2a8331db8..c67fb0c62 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -12,73 +12,63 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" -#ifdef __APPLE__ -#include "mediapipe/objc/util.h" -#endif - namespace mediapipe { +namespace api2 { -// Convert ImageFrame to GpuBuffer. 
-class ImageFrameToGpuBufferCalculator : public CalculatorBase {
+class ImageFrameToGpuBufferCalculator
+    : public RegisteredNode<ImageFrameToGpuBufferCalculator> {
  public:
-  ImageFrameToGpuBufferCalculator() {}
+  static constexpr Input<ImageFrame> kIn{""};
+  static constexpr Output<GpuBuffer> kOut{""};
 
-  static absl::Status GetContract(CalculatorContract* cc);
+  MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut);
+
+  static absl::Status UpdateContract(CalculatorContract* cc);
 
   absl::Status Open(CalculatorContext* cc) override;
   absl::Status Process(CalculatorContext* cc) override;
 
  private:
-#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   GlCalculatorHelper helper_;
-#endif  // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 };
-REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator);
 
 // static
-absl::Status ImageFrameToGpuBufferCalculator::GetContract(
+absl::Status ImageFrameToGpuBufferCalculator::UpdateContract(
     CalculatorContract* cc) {
-  cc->Inputs().Index(0).Set<ImageFrame>();
-  cc->Outputs().Index(0).Set<GpuBuffer>();
   // Note: we call this method even on platforms where we don't use the helper,
   // to ensure the calculator's contract is the same. In particular, the helper
   // enables support for the legacy side packet, which several graphs still use.
-  MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
-  return absl::OkStatus();
+  return GlCalculatorHelper::UpdateContract(cc);
 }
 
 absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) {
-  // Inform the framework that we always output at the same timestamp
-  // as we receive a packet at.
-  cc->SetOffset(TimestampDiff(0));
-#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   MP_RETURN_IF_ERROR(helper_.Open(cc));
-#endif  // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   return absl::OkStatus();
 }
 
 absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) {
-#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
-  CFHolder<CVPixelBufferRef> buffer;
-  MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket(
-      cc->Inputs().Index(0).Value(), &buffer));
-  cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp());
-#else
-  const auto& input = cc->Inputs().Index(0).Get<ImageFrame>();
-  helper_.RunInGlContext([this, &input, &cc]() {
-    auto src = helper_.CreateSourceTexture(input);
-    auto output = src.GetFrame<GpuBuffer>();
-    glFlush();
-    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
-    src.Release();
-  });
-#endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+  auto image_frame = std::const_pointer_cast<ImageFrame>(
+      mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet()));
+  auto gpu_buffer = api2::MakePacket<GpuBuffer>(
+                        std::make_shared<mediapipe::GpuBufferStorageImageFrame>(
+                            std::move(image_frame)))
+                        .At(cc->InputTimestamp());
+  // This calculator's behavior has been to do the texture upload eagerly, and
+  // some graphs may rely on running this on a separate GL context to avoid
+  // blocking another context with the read operation. So let's request GPU
+  // access here to ensure that the behavior stays the same.
+  // TODO: have a better way to do this, or defer until later.
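+  // Requesting a read view below forces the ImageFrame-to-GL-texture upload
+  // to happen now; the view object itself is discarded immediately, since
+  // only the upload side effect is wanted.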
+ helper_.RunInGlContext( + [&gpu_buffer] { auto view = gpu_buffer->GetReadView(0); }); + kOut(cc).Send(std::move(gpu_buffer)); return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe From 837225c53d55700ff485367bb0fa71890f905e2e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 22 Nov 2022 17:30:23 -0800 Subject: [PATCH 043/346] Internal change PiperOrigin-RevId: 490374976 --- mediapipe/framework/validated_graph_config.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 16aad6e9b..01e3da83e 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -1048,6 +1048,14 @@ absl::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( for (const auto& required_item : required_side_packets_) { auto iter = side_packet_types.find(required_item.first); if (iter == side_packet_types.end()) { + bool is_optional = true; + for (int index : required_item.second) { + is_optional &= input_side_packets_[index].packet_type->IsOptional(); + } + if (is_optional) { + // Side packets that are optional and not provided are ignored. + continue; + } statuses.push_back(mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Side packet \"" << required_item.first << "\" is required but was not provided."); From 3bbc0e9af9150797142295f47b1d87a0403d8f44 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 22 Nov 2022 17:34:58 -0800 Subject: [PATCH 044/346] Internal change PiperOrigin-RevId: 490375672 --- mediapipe/tasks/web/BUILD | 18 +++--------------- mediapipe/tasks/web/audio.ts | 3 +-- mediapipe/tasks/web/text.ts | 3 +-- mediapipe/tasks/web/vision.ts | 6 +----- 4 files changed, 6 insertions(+), 24 deletions(-) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index af76a1fe8..7e5d02892 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -24,10 +24,7 @@ mediapipe_files(srcs = [ mediapipe_ts_library( name = "audio_lib", srcs = ["audio.ts"], - deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", - "//mediapipe/tasks/web/audio/audio_embedder", - ], + deps = ["//mediapipe/tasks/web/audio:audio_lib"], ) rollup_bundle( @@ -69,10 +66,7 @@ pkg_npm( mediapipe_ts_library( name = "text_lib", srcs = ["text.ts"], - deps = [ - "//mediapipe/tasks/web/text/text_classifier", - "//mediapipe/tasks/web/text/text_embedder", - ], + deps = ["//mediapipe/tasks/web/text:text_lib"], ) rollup_bundle( @@ -114,13 +108,7 @@ pkg_npm( mediapipe_ts_library( name = "vision_lib", srcs = ["vision.ts"], - deps = [ - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/hand_landmarker", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_embedder", - "//mediapipe/tasks/web/vision/object_detector", - ], + deps = ["//mediapipe/tasks/web/vision:vision_lib"], ) rollup_bundle( diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 056426f50..8c522efcc 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,8 +14,7 @@ * limitations under the License. 
*/ -import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; -import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder'; +import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts index 39d101237..8f15075c5 100644 --- a/mediapipe/tasks/web/text.ts +++ b/mediapipe/tasks/web/text.ts @@ -14,8 +14,7 @@ * limitations under the License. */ -import {TextClassifier as TextClassifierImpl} from '../../tasks/web/text/text_classifier/text_classifier'; -import {TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/text_embedder/text_embedder'; +import {TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 4e4fab43f..74a056464 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,11 +14,7 @@ * limitations under the License. */ -import {GestureRecognizer as GestureRecognizerImpl} from '../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -import {HandLandmarker as HandLandmarkerImpl} from '../../tasks/web/vision/hand_landmarker/hand_landmarker'; -import {ImageClassifier as ImageClassifierImpl} from '../../tasks/web/vision/image_classifier/image_classifier'; -import {ImageEmbedder as ImageEmbedderImpl} from '../../tasks/web/vision/image_embedder/image_embedder'; -import {ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/object_detector/object_detector'; +import {GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. From a55839de51dafe27b4c2b705954444895a842c3c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 22 Nov 2022 18:07:26 -0800 Subject: [PATCH 045/346] This storage only needs a "done writing" callback on simulator, so only set it there - When not on simulator, we pass nullptr instead of a do-nothing callback. - The callback is no longer a method, but a function. Only the CVPixelBuffer is captured. 
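In sketch form (condensed from the diff below), the capture changes from

    GetTexture(plane, [this](const mediapipe::GlTextureView& view) {
      ViewDoneWriting(view);
    });

to

    GetTexture(plane, [pixel_buffer = CFHolder<CVPixelBufferRef>(*this)](
                          const mediapipe::GlTextureView& view) {
      ViewDoneWritingSimulatorWorkaround(*pixel_buffer, view);
    });

so the callback no longer depends on the storage object staying alive.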
PiperOrigin-RevId: 490380248 --- .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 45 +++++++++++-------- .../gpu/gpu_buffer_storage_cv_pixel_buffer.h | 1 - 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index f3954a6e4..014cc1c69 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -70,25 +70,9 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( return GetTexture(plane, nullptr); } -GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, int plane) { - return GetTexture(plane, [this](const mediapipe::GlTextureView& view) { - ViewDoneWriting(view); - }); -} - -std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types) const { - return CreateImageFrameForCVPixelBuffer(**this); -} -std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types) { - return CreateImageFrameForCVPixelBuffer(**this); -} - -void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { #if TARGET_IPHONE_SIMULATOR - CVPixelBufferRef pixel_buffer = **this; +static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, + const GlTextureView& view) { CHECK(pixel_buffer); CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); CHECK(err == kCVReturnSuccess) @@ -126,7 +110,30 @@ void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); CHECK(err == kCVReturnSuccess) << "CVPixelBufferUnlockBaseAddress failed: " << err; -#endif +} +#endif // TARGET_IPHONE_SIMULATOR + +GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types, int plane) { + return GetTexture(plane, +#if TARGET_IPHONE_SIMULATOR + [pixel_buffer = CFHolder(*this)]( + const mediapipe::GlTextureView& view) { + ViewDoneWritingSimulatorWorkaround(*pixel_buffer, view); + } +#else + nullptr +#endif // TARGET_IPHONE_SIMULATOR + ); +} + +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types) const { + return CreateImageFrameForCVPixelBuffer(**this); +} +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types) { + return CreateImageFrameForCVPixelBuffer(**this); } static std::shared_ptr ConvertFromImageFrame( diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index a9389ab8a..8723a1087 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -63,7 +63,6 @@ class GpuBufferStorageCvPixelBuffer private: GlTextureView GetTexture(int plane, GlTextureView::DoneWritingFn done_writing) const; - void ViewDoneWriting(const GlTextureView& view); }; inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( From 05681fc0e17089a4e1d3f999bd17f3020cabb9bc Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Nov 2022 01:26:15 -0800 Subject: [PATCH 046/346] Internal PiperOrigin-RevId: 490439195 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 8b09260bd..762184842 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -18,7 +18,6 @@ load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build load("@build_bazel_rules_android//android:rules.bzl", "android_library") _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", From c5ce5236972a6045f42bb23d526ebb27a7e58bb7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Nov 2022 02:02:18 -0800 Subject: [PATCH 047/346] Add cosine APIs to Embedder tasks PiperOrigin-RevId: 490444597 --- .../tasks/web/audio/audio_embedder/BUILD | 1 + .../audio/audio_embedder/audio_embedder.ts | 15 +++++ mediapipe/tasks/web/components/utils/BUILD | 11 ++++ .../web/components/utils/cosine_similarity.ts | 62 +++++++++++++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 1 + .../web/text/text_embedder/text_embedder.ts | 15 +++++ .../tasks/web/vision/image_embedder/BUILD | 1 + .../vision/image_embedder/image_embedder.ts | 15 +++++ 8 files changed, 121 insertions(+) create mode 100644 mediapipe/tasks/web/components/utils/BUILD create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.ts diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 7d9a994a3..1a66464bd 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 46a7b6729..9dce02862 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -20,8 +20,10 @@ import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../.. 
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -144,6 +146,19 @@ export class AudioEmbedder extends AudioTaskRunner { return this.processAudioClip(audioData, sampleRate); } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD new file mode 100644 index 000000000..1c1ba69ca --- /dev/null +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -0,0 +1,11 @@ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts new file mode 100644 index 000000000..fb1d0c185 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -0,0 +1,62 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +/** + * Computes cosine similarity[1] between two `Embedding` objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), + * have different sizes, or have an L2-norm of 0. + */ +export function computeCosineSimilarity(u: Embedding, v: Embedding): number { + if (u.floatEmbedding && v.floatEmbedding) { + return compute(u.floatEmbedding, v.floatEmbedding); + } + if (u.quantizedEmbedding && v.quantizedEmbedding) { + return compute( + convertToBytes(u.quantizedEmbedding), + convertToBytes(v.quantizedEmbedding)); + } + throw new Error( + 'Cannot compute cosine similarity between quantized and float embeddings.'); +} +function convertToBytes(data: Uint8Array): number[] { + return Array.from(data, v => v - 128); +} + +function compute(u: number[], v: number[]) { + if (u.length !== v.length) { + throw new Error( + `Cannot compute cosine similarity between embeddings of different sizes (${ + u.length} vs. ${v.length}).`); + } + let dotProduct = 0.0; + let normU = 0.0; + let normV = 0.0; + for (let i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new Error( + 'Cannot compute cosine similarity on embedding with 0 norm.'); + } + return dotProduct / Math.sqrt(normU * normV); +} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index c555f8d33..3f92b8ae1 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 57b91d575..2042a0985 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -18,9 +18,11 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from 
'../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; @@ -143,6 +145,19 @@ export class TextEmbedder extends TaskRunner { return this.embeddingResult; } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Updates the MediaPipe graph configuration. */ private refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index feb3ae054..2f012dc5e 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -21,6 +21,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/vision/core:vision_task_options", diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index c60665052..f96f1e961 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -19,8 +19,10 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; @@ -157,6 +159,19 @@ export class ImageEmbedder extends VisionTaskRunner { return this.processVideoData(imageFrame, timestamp); } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. 
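+ *
+ * A minimal usage sketch (`resultA` and `resultB` are hypothetical results
+ * from two separate embedding calls):
+ *
+ *   const similarity = ImageEmbedder.cosineSimilarity(
+ *       resultA.embeddings[0], resultB.embeddings[0]);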
+ */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Runs the embedding extraction and blocks on the response. */ protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { From b5189758f7fc913e050ae0e6d4f7f999365e8118 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 02:03:35 -0800 Subject: [PATCH 048/346] Move ImagePreprocessing to "processors" folder. PiperOrigin-RevId: 490444821 --- mediapipe/tasks/cc/components/BUILD | 45 --- .../tasks/cc/components/processors/BUILD | 33 ++ .../image_preprocessing_graph.cc} | 42 ++- .../image_preprocessing_graph.h} | 26 +- .../image_preprocessing_graph_test.cc | 343 ++++++++++++++++++ .../cc/components/processors/proto/BUILD | 10 + .../image_preprocessing_graph_options.proto} | 6 +- .../tasks/cc/vision/gesture_recognizer/BUILD | 4 - .../gesture_recognizer/gesture_recognizer.cc | 1 - .../hand_gesture_recognizer_graph.cc | 2 - mediapipe/tasks/cc/vision/hand_detector/BUILD | 2 +- .../hand_detector/hand_detector_graph.cc | 20 +- .../tasks/cc/vision/hand_landmarker/BUILD | 3 +- .../vision/hand_landmarker/hand_landmarker.cc | 1 - .../hand_landmarks_detector_graph.cc | 17 +- .../tasks/cc/vision/image_classifier/BUILD | 4 +- .../image_classifier_graph.cc | 19 +- .../tasks/cc/vision/image_embedder/BUILD | 4 +- .../image_embedder/image_embedder_graph.cc | 19 +- .../tasks/cc/vision/image_segmenter/BUILD | 4 +- .../image_segmenter/image_segmenter_graph.cc | 19 +- .../tasks/cc/vision/object_detector/BUILD | 2 +- .../object_detector/object_detector_graph.cc | 17 +- 23 files changed, 493 insertions(+), 150 deletions(-) rename mediapipe/tasks/cc/components/{image_preprocessing.cc => processors/image_preprocessing_graph.cc} (90%) rename mediapipe/tasks/cc/components/{image_preprocessing.h => processors/image_preprocessing_graph.h} (72%) create mode 100644 mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc rename mediapipe/tasks/cc/components/{image_preprocessing_options.proto => processors/proto/image_preprocessing_graph_options.proto} (89%) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index c90349ab2..54a5207d2 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -12,55 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "image_preprocessing_options_proto", - srcs = ["image_preprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -cc_library( - name = "image_preprocessing", - srcs = ["image_preprocessing.cc"], - hdrs = ["image_preprocessing.h"], - deps = [ - ":image_preprocessing_options_cc_proto", - "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/calculators/image:image_clone_calculator", - "//mediapipe/calculators/image:image_clone_calculator_cc_proto", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/gpu:gpu_origin_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - -# TODO: Enable this test - # TODO: Investigate rewriting the build rule to only link # the Bert Preprocessor if it's needed. 
cc_library( diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 32a628db7..4946683f5 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -100,3 +100,36 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "image_preprocessing_graph", + srcs = ["image_preprocessing_graph.cc"], + hdrs = ["image_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_clone_calculator", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc similarity index 90% rename from mediapipe/tasks/cc/components/image_preprocessing.cc rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index ef447df97..b24b7f0cb 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include #include @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" @@ -42,6 +42,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::Tensor; @@ -144,9 +145,9 @@ bool DetermineImagePreprocessingGpuBackend( return acceleration.has_gpu(); } -absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, - bool use_gpu, - ImagePreprocessingOptions* options) { +absl::Status ConfigureImagePreprocessingGraph( + const ModelResources& model_resources, bool use_gpu, + proto::ImagePreprocessingGraphOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator( @@ -154,9 +155,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { - options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::GPU_BACKEND); } else { - options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::CPU_BACKEND); } return absl::OkStatus(); } @@ -170,8 +171,7 @@ Source AddDataConverter(Source image_in, Graph& graph, return image_converter[Output("")]; } -// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image -// preprocessing. +// An ImagePreprocessingGraph performs image preprocessing. // - Accepts CPU input images and outputs CPU tensors. // // Inputs: @@ -192,7 +192,7 @@ Source AddDataConverter(Source image_in, Graph& graph, // An std::array representing the letterbox padding from the 4 // sides ([left, top, right, bottom]) of the output image, normalized to // [0.f, 1.f] by the output dimensions. The padding values are non-zero only -// when the "keep_aspect_ratio" is true in ImagePreprocessingOptions. +// when the "keep_aspect_ratio" is true in ImagePreprocessingGraphOptions. // IMAGE_SIZE - std::pair @Optional // The size of the original input image as a pair. // IMAGE - Image @Optional @@ -200,15 +200,15 @@ Source AddDataConverter(Source image_in, Graph& graph, // GPU). // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureImagePreprocessing()' function. See header file for more -// details. -class ImagePreprocessingSubgraph : public Subgraph { +// using the 'ConfigureImagePreprocessingGraph()' function. See header file for +// more details. +class ImagePreprocessingGraph : public Subgraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; auto output_streams = BuildImagePreprocessing( - sc->Options(), + sc->Options(), graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph); output_streams.tensors >> graph[Output>(kTensorsTag)]; @@ -233,24 +233,25 @@ class ImagePreprocessingSubgraph : public Subgraph { // - the image that has pixel data stored on the target storage // (mediapipe::Image). // - // options: the mediapipe tasks ImagePreprocessingOptions. + // options: the mediapipe tasks ImagePreprocessingGraphOptions. // image_in: (mediapipe::Image) stream to preprocess. // graph: the mediapipe builder::Graph instance to be updated. 
ImagePreprocessingOutputStreams BuildImagePreprocessing( - const ImagePreprocessingOptions& options, Source image_in, - Source norm_rect_in, Graph& graph) { + const proto::ImagePreprocessingGraphOptions& options, + Source image_in, Source norm_rect_in, + Graph& graph) { // Convert image to tensor. auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); image_to_tensor.GetOptions() .CopyFrom(options.image_to_tensor_options()); switch (options.backend()) { - case ImagePreprocessingOptions::CPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::CPU_BACKEND: { auto cpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/false); cpu_image >> image_to_tensor.In(kImageTag); break; } - case ImagePreprocessingOptions::GPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::GPU_BACKEND: { auto gpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/true); gpu_image >> image_to_tensor.In(kImageTag); @@ -284,8 +285,9 @@ class ImagePreprocessingSubgraph : public Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ImagePreprocessingSubgraph); + ::mediapipe::tasks::components::processors::ImagePreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h similarity index 72% rename from mediapipe/tasks/cc/components/image_preprocessing.h rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h index 6963b6556..455a9b316 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h @@ -13,35 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures an ImagePreprocessing subgraph using the provided model resources +// Configures an ImagePreprocessingGraph using the provided model resources // When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. 
// // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.ImagePreprocessingGraph"); // core::proto::Acceleration acceleration; // acceleration.mutable_xnnpack(); // bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); -// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( +// MP_RETURN_IF_ERROR(ConfigureImagePreprocessingGraph( // model_resources, // use_gpu, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ImagePreprocessing subgraph has the following I/O: +// The resulting ImagePreprocessingGraph has the following I/O: // Inputs: // IMAGE - Image // The image to preprocess. @@ -61,17 +62,18 @@ namespace components { // IMAGE - Image @Optional // The image that has the pixel data stored on the target storage (CPU vs // GPU). -absl::Status ConfigureImagePreprocessing( +absl::Status ConfigureImagePreprocessingGraph( const core::ModelResources& model_resources, bool use_gpu, - ImagePreprocessingOptions* options); + proto::ImagePreprocessingGraphOptions* options); -// Determine if the image preprocessing subgraph should use GPU as the backend +// Determine if the image preprocessing graph should use GPU as the backend // according to the given acceleration setting. bool DetermineImagePreprocessingGpuBackend( const core::proto::Acceleration& acceleration); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc new file mode 100644 index 000000000..6c094c6bc --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc @@ -0,0 +1,343 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::testing::ContainerEq; +using ::testing::HasSubstr; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; +constexpr char kMobileNetFloatWithoutMetadata[] = + "mobilenet_v1_0.25_224_1_default_1.tflite"; +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kMobileNetQuantizedWithoutMetadata[] = + "mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; + +constexpr char kTestImage[] = "burger.jpg"; +constexpr int kTestImageWidth = 480; +constexpr int kTestImageHeight = 325; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; +constexpr std::array kIdentityMatrix = {1, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 1}; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kMatrixName[] = "matrix_out"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTensorsName[] = "tensors_out"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kImageSizeName[] = "image_size_out"; +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kLetterboxPaddingName[] = "letterbox_padding_out"; + +constexpr float kLetterboxMaxAbsError = 1e-5; + +// Helper function to get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +// Helper function to create a TaskRunner from ModelResources. 
+absl::StatusOr> CreateTaskRunner( + const ModelResources& model_resources, bool keep_aspect_ratio) { + Graph graph; + + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + auto& options = + preprocessing.GetOptions(); + options.mutable_image_to_tensor_options()->set_keep_aspect_ratio( + keep_aspect_ratio); + MP_RETURN_IF_ERROR( + ConfigureImagePreprocessingGraph(model_resources, false, &options)); + graph[Input(kImageTag)].SetName(kImageName) >> + preprocessing.In(kImageTag); + preprocessing.Out(kTensorsTag).SetName(kTensorsName) >> + graph[Output>(kTensorsTag)]; + preprocessing.Out(kMatrixTag).SetName(kMatrixName) >> + graph[Output>(kMatrixTag)]; + preprocessing.Out(kImageSizeTag).SetName(kImageSizeName) >> + graph[Output>(kImageSizeTag)]; + preprocessing.Out(kLetterboxPaddingTag).SetName(kLetterboxPaddingName) >> + graph[Output>(kLetterboxPaddingTag)]; + + return TaskRunner::Create(graph.GetConfig()); +} + +class ConfigureTest : public tflite_shims::testing::Test {}; + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 192 + output_tensor_height: 192 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelFallbacksCpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + 
CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: GPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + auto status = + ConfigureImagePreprocessingGraph(*model_resources, false, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_THAT(status.message(), + HasSubstr("requires specifying NormalizationOptions metadata")); +} + +// Struct holding the parameters for parameterized PreprocessingTest class. +struct PreprocessingParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // If true, keep test image aspect ratio. + bool keep_aspect_ratio; + // The expected output tensor type. + Tensor::ElementType expected_type; + // The expected outoput tensor shape. + std::vector expected_shape; + // The expected output letterbox padding; + std::array expected_letterbox_padding; +}; + +class PreprocessingTest : public testing::TestWithParam {}; + +TEST_P(PreprocessingTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kTestImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(GetParam().input_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, + CreateTaskRunner(*model_resources, GetParam().keep_aspect_ratio)); + + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + MP_ASSERT_OK(output_packets); + + const std::vector& tensors = + (*output_packets)[kTensorsName].Get>(); + EXPECT_EQ(tensors.size(), 1); + EXPECT_EQ(tensors[0].element_type(), GetParam().expected_type); + EXPECT_THAT(tensors[0].shape().dims, ContainerEq(GetParam().expected_shape)); + auto& matrix = (*output_packets)[kMatrixName].Get>(); + if (!GetParam().keep_aspect_ratio) { + for (int i = 0; i < matrix.size(); ++i) { + EXPECT_FLOAT_EQ(matrix[i], kIdentityMatrix[i]); + } + } + auto& image_size = + (*output_packets)[kImageSizeName].Get>(); + EXPECT_EQ(image_size.first, kTestImageWidth); + EXPECT_EQ(image_size.second, kTestImageHeight); + std::array letterbox_padding = + (*output_packets)[kLetterboxPaddingName].Get>(); + for (int i = 0; i < letterbox_padding.size(); ++i) { + EXPECT_NEAR(letterbox_padding[i], GetParam().expected_letterbox_padding[i], + kLetterboxMaxAbsError); + } +} + +INSTANTIATE_TEST_SUITE_P( + PreprocessingTest, PreprocessingTest, + Values( + PreprocessingParams{.test_name = "kMobileNetQuantizedWithMetadata", + .input_model_name = kMobileNetQuantizedWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetQuantizedWithoutMetadata", + 
.input_model_name = kMobileNetQuantizedWithoutMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 192, 192, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{.test_name = "kMobileNetFloatWithMetadata", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetFloatWithMetadataKeepAspectRatio", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = true, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {/*left*/ 0, + /*top*/ 0.161458, + /*right*/ 0, + /*bottom*/ 0.161458}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 23ebbe008..9c58a8585 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -49,3 +49,13 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", ], ) + +mediapipe_proto_library( + name = "image_preprocessing_graph_options_proto", + srcs = ["image_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/image_preprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/components/image_preprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto index d1685c319..bf4fc9067 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto @@ -15,14 +15,14 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto"; import "mediapipe/framework/calculator.proto"; -message ImagePreprocessingOptions { +message ImagePreprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ImagePreprocessingOptions ext = 456882436; + optional ImagePreprocessingGraphOptions ext = 456882436; } // Options for the ImageToTensor calculator encapsulated by the diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 7b144e7aa..d473a8dc3 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", @@ -105,10 +104,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_cache", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 8d555b12c..e7fcf6fd9 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 7b6a8c79d..d7e983d81 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -29,8 +29,6 @@ limitations under the License. 
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" -#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 71cef6270..55162d09b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -46,7 +46,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 06bb2e549..c24548c9b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -226,21 +226,23 @@ class HandDetectorGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); auto& image_to_tensor_options = *preprocessing - .GetOptions() + .GetOptions() .mutable_image_to_tensor_options(); image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In("IMAGE"); norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 3b869eab4..46948ee6c 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -35,7 +35,6 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", @@ -89,7 +88,7 @@ cc_library( "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 3a9ed5bc2..2b818b2e5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -22,7 +22,6 @@ limitations under the License. #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 1f127deb8..014830ba2 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. 
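The hand detector hunk above establishes the migration pattern that the hand landmarker, image classifier, image embedder, image segmenter, and object detector hunks below repeat. Several of those hunks show a bare `GetOptions()`; the template argument visible in the hand detector hunk appears to have been dropped from the others in transit. A condensed sketch of the new wiring — the wrapper function and its parameter list are illustrative, not code from the patch, while the node name, configure call, and options type are taken verbatim from the hunk above:

```cpp
#include "absl/status/status.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

namespace mediapipe::tasks {

using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

// Hypothetical helper condensing the pattern repeated throughout this patch.
absl::Status AddImagePreprocessing(
    const core::ModelResources& model_resources,
    const core::proto::Acceleration& acceleration,  // from base_options()
    Source<Image> image_in, Source<NormalizedRect> norm_rect_in,
    Graph& graph) {
  auto& preprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
  bool use_gpu =
      components::processors::DetermineImagePreprocessingGpuBackend(
          acceleration);
  MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
      model_resources, use_gpu,
      &preprocessing.GetOptions<
          components::processors::proto::ImagePreprocessingGraphOptions>()));
  image_in >> preprocessing.In("IMAGE");
  norm_rect_in >> preprocessing.In("NORM_RECT");
  return absl::OkStatus();
}

}  // namespace mediapipe::tasks
```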
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -281,14 +281,15 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In("IMAGE"); hand_rect >> preprocessing.In("NORM_RECT"); auto image_size = preprocessing[Output>("IMAGE_SIZE")]; diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 2b93aa262..514e601ef 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -59,11 +59,11 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2fc88bcb6..2d0379c66 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -135,14 +135,15 @@ class ImageClassifierGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index 8fdb97ccd..d729eaf1a 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -57,12 +57,12 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index bf0dcf3c7..81ccb5361 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -130,14 +130,15 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 595eef568..2124fe6e0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -56,10 +56,10 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 44742e043..d5eb5af0d 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -27,8 +27,8 @@ limitations under the License. 
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -243,14 +243,15 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index b8002fa96..c2dd9995d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -71,9 +71,9 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index b149cea0f..f5dc7e061 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -561,14 +561,15 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); From 3c53ec2cdbe5df2aabf6a20f3b6c9b4efa76cb71 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 10:09:42 -0800 Subject: [PATCH 049/346] Do not expose DrishtiGraphGPUData.h in public header This class is an implementation detail. PiperOrigin-RevId: 490530823 --- mediapipe/gpu/BUILD | 7 +------ mediapipe/gpu/MPPMetalHelper.h | 24 +++++++++++------------- mediapipe/gpu/MPPMetalHelper.mm | 6 ++++++ mediapipe/objc/MPPGraph.mm | 1 - 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index f97eed678..42cd9cdc6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -550,12 +550,7 @@ cc_library( name = "gpu_shared_data_header", textual_hdrs = [ "gpu_shared_data_internal.h", - ] + select({ - "//conditions:default": [], - "//mediapipe:apple": [ - "MPPGraphGPUData.h", - ], - }), + ], visibility = ["//visibility:private"], deps = [ ":gl_base", diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h index f3662422e..6ae0f3cf9 100644 --- a/mediapipe/gpu/MPPMetalHelper.h +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -21,37 +21,35 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" NS_ASSUME_NONNULL_BEGIN @interface MPPMetalHelper : NSObject { - MPPGraphGPUData* _gpuShared; } - (instancetype)init NS_UNAVAILABLE; /// Initialize. This initializer is recommended for calculators. -- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc; +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext *)cc; /// Initialize. -- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources +- (instancetype)initWithGpuResources:(mediapipe::GpuResources *)gpuResources NS_DESIGNATED_INITIALIZER; /// Configures a calculator's contract for accessing GPU resources. /// Calculators should use this in GetContract. -+ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc; ++ (absl::Status)updateContract:(mediapipe::CalculatorContract *)cc; /// Deprecated initializer. -- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet &)inputSidePackets; /// Deprecated initializer. -- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared; +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData *)gpuShared; /// Configures a calculator's side packets for accessing GPU resources. /// Calculators should use this in FillExpectations. 
-+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet *)inputSidePackets; /// Get a metal command buffer. /// Calculators should use this method instead of getting a buffer from the @@ -63,23 +61,23 @@ NS_ASSUME_NONNULL_BEGIN /// Creates a CVMetalTextureRef linked to the provided GpuBuffer. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Creates a CVMetalTextureRef linked to the provided GpuBuffer given a specific plane. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Returns a MTLTexture linked to the provided GpuBuffer. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Returns a MTLTexture linked to the provided GpuBuffer given a specific plane. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Obtains a new GpuBuffer to be used as an output destination. @@ -91,7 +89,7 @@ NS_ASSUME_NONNULL_BEGIN format:(mediapipe::GpuBufferFormat)format; /// Convenience method to load a Metal library stored as a bundle resource. -- (id)newLibraryWithResourceName:(NSString*)name error:(NSError* _Nullable*)error; +- (id)newLibraryWithResourceName:(NSString *)name error:(NSError *_Nullable *)error; /// Shared Metal resources. @property(readonly) id mtlDevice; diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index ce6620972..dc1e27a5c 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,11 +14,17 @@ #import "mediapipe/gpu/MPPMetalHelper.h" +#import "mediapipe/gpu/MPPGraphGPUData.h" #import "mediapipe/gpu/graph_support.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" +@interface MPPMetalHelper () { + MPPGraphGPUData* _gpuShared; +} +@end + namespace mediapipe { // Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. 
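The MPPMetalHelper hunks above relocate the `_gpuShared` ivar from the public header into a class extension in the .mm, so the implementation type no longer leaks into MPPMetalHelper.h. For readers more used to C++, the closest analogue of this information-hiding move is the pimpl idiom — a minimal sketch, not MediaPipe code:

```cpp
// widget.h -- public header: clients see only an opaque pointer.
#include <memory>

class Widget {
 public:
  Widget();
  ~Widget();  // defined out of line, where Impl is a complete type
 private:
  struct Impl;                  // forward declaration only
  std::unique_ptr<Impl> impl_;  // implementation details stay in widget.cc
};

// widget.cc -- the formerly exposed members now live here:
//   struct Widget::Impl { /* ... */ };
//   Widget::Widget() : impl_(std::make_unique<Impl>()) {}
//   Widget::~Widget() = default;
```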
diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 080cca20f..1bd177e80 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -24,7 +24,6 @@ #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/graph_service.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/objc/util.h" From 54d1744c8f5ee102679386b84e3e3812e352bc7a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 10:13:48 -0800 Subject: [PATCH 050/346] Remove DrishtiGraphGPUData, add MetalSharedResources This class is unused except by the Metal helper; let's narrow it down and simplify gpu_shared_data. PiperOrigin-RevId: 490531767 --- mediapipe/gpu/BUILD | 50 +++------ mediapipe/gpu/MPPGraphGPUData.h | 71 ------------- mediapipe/gpu/MPPGraphGPUData.mm | 124 ---------------------- mediapipe/gpu/MPPGraphGPUDataTests.mm | 86 --------------- mediapipe/gpu/MPPMetalHelper.mm | 31 +++--- mediapipe/gpu/gpu_shared_data_internal.cc | 13 +-- mediapipe/gpu/gpu_shared_data_internal.h | 18 ++-- mediapipe/objc/BUILD | 2 +- 8 files changed, 46 insertions(+), 349 deletions(-) delete mode 100644 mediapipe/gpu/MPPGraphGPUData.h delete mode 100644 mediapipe/gpu/MPPGraphGPUData.mm delete mode 100644 mediapipe/gpu/MPPGraphGPUDataTests.mm diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 42cd9cdc6..9cc670fb6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -470,12 +470,9 @@ objc_library( ) objc_library( - name = "MPPGraphGPUData", - srcs = [ - "MPPGraphGPUData.mm", - "gpu_shared_data_internal.cc", - ], - hdrs = ["MPPGraphGPUData.h"], + name = "metal_shared_resources", + srcs = ["metal_shared_resources.mm"], + hdrs = ["metal_shared_resources.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", @@ -484,25 +481,9 @@ objc_library( sdk_frameworks = [ "CoreVideo", "Metal", - ] + select({ - "//conditions:default": [ - "OpenGLES", - ], - "//mediapipe:macos": [ - "OpenGL", - "AppKit", - ], - }), + ], visibility = ["//visibility:public"], deps = [ - ":gl_base", - ":gl_context", - ":gpu_buffer_multi_pool", - ":gpu_shared_data_header", - ":graph_support", - ":cv_texture_cache_manager", - "//mediapipe/gpu:gl_context_options_cc_proto", - "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", "@google_toolbox_for_mac//:GTM_Defines", ] + [ @@ -584,16 +565,19 @@ cc_library( cc_library( name = "gpu_shared_data_internal_actual", - srcs = select({ - "//conditions:default": [ - "gpu_shared_data_internal.cc", - ], - # iOS uses an Objective-C++ version of this, built in MPPGraphGPUData. 
- "//mediapipe:apple": [], - }), + srcs = [ + "gpu_shared_data_internal.cc", + ], hdrs = [ "gpu_shared_data_internal.h", ], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + }), visibility = ["//visibility:private"], deps = [ "//mediapipe/gpu:gl_context_options_cc_proto", @@ -610,7 +594,7 @@ cc_library( ] + select({ "//conditions:default": [], "//mediapipe:apple": [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":cv_texture_cache_manager", ], }), @@ -1139,8 +1123,8 @@ objc_library( name = "gl_ios_test_lib", testonly = 1, srcs = [ - "MPPGraphGPUDataTests.mm", "gl_ios_test.mm", + "metal_shared_resources_test.mm", ], copts = [ "-Wno-shorten-64-to-32", @@ -1150,7 +1134,7 @@ objc_library( ], features = ["-layering_check"], deps = [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":gl_scaler_calculator", ":gpu_buffer_to_image_frame_calculator", ":gpu_shared_data_internal", diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h deleted file mode 100644 index 3d8fc0c94..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ -#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ - -#import -#import -#import - -#import "mediapipe/gpu/gl_base.h" -#import "mediapipe/gpu/gl_context.h" - -namespace mediapipe { -class GlContext; -class GpuBufferMultiPool; -} // namespace mediapipe - -@interface MPPGraphGPUData : NSObject { - // Shared buffer pool for GPU calculators. - mediapipe::GpuBufferMultiPool* _gpuBufferPool; - mediapipe::GlContext* _glContext; -} - -- (instancetype)init NS_UNAVAILABLE; - -/// Initialize. The provided multipool pointer must remain valid throughout -/// this object's lifetime. -- (instancetype)initWithContext:(mediapipe::GlContext*)context - multiPool:(mediapipe::GpuBufferMultiPool*)pool NS_DESIGNATED_INITIALIZER; - -/// Shared texture pool for GPU calculators. -/// For internal use by GlCalculatorHelper. -@property(readonly) mediapipe::GpuBufferMultiPool* gpuBufferPool; - -/// Shared OpenGL context. -#if TARGET_OS_OSX -@property(readonly) NSOpenGLContext* glContext; -@property(readonly) NSOpenGLPixelFormat* glPixelFormat; -#else -@property(readonly) EAGLContext* glContext; -#endif // TARGET_OS_OSX - -/// Shared texture cache. -#if TARGET_OS_OSX -@property(readonly) CVOpenGLTextureCacheRef textureCache; -#else -@property(readonly) CVOpenGLESTextureCacheRef textureCache; -#endif // TARGET_OS_OSX - -/// Shared Metal resources. 
-@property(readonly) id mtlDevice; -@property(readonly) id mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@property(readonly) CVMetalTextureCacheRef mtlTextureCache; -#endif - -@end - -#endif // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm deleted file mode 100644 index 8ac1eefa5..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/gpu/MPPGraphGPUData.h" - -#import "GTMDefines.h" - -#include "mediapipe/gpu/gl_context.h" -#include "mediapipe/gpu/gpu_buffer_multi_pool.h" - -#if TARGET_OS_OSX -#import -#else -#import -#endif // TARGET_OS_OSX - -@implementation MPPGraphGPUData - -@synthesize textureCache = _textureCache; -@synthesize mtlDevice = _mtlDevice; -@synthesize mtlCommandQueue = _mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@synthesize mtlTextureCache = _mtlTextureCache; -#endif - -#if TARGET_OS_OSX -typedef CVOpenGLTextureCacheRef CVTextureCacheType; -#else -typedef CVOpenGLESTextureCacheRef CVTextureCacheType; -#endif // TARGET_OS_OSX - -- (instancetype)initWithContext:(mediapipe::GlContext *)context - multiPool:(mediapipe::GpuBufferMultiPool *)pool { - self = [super init]; - if (self) { - _gpuBufferPool = pool; - _glContext = context; - } - return self; -} - -- (void)dealloc { - if (_textureCache) { - _textureCache = NULL; - } -#if COREVIDEO_SUPPORTS_METAL - if (_mtlTextureCache) { - CFRelease(_mtlTextureCache); - _mtlTextureCache = NULL; - } -#endif -} - -#if TARGET_OS_OSX -- (NSOpenGLContext *)glContext { - return _glContext->nsgl_context(); -} - -- (NSOpenGLPixelFormat *) glPixelFormat { - return _glContext->nsgl_pixel_format(); -} -#else -- (EAGLContext *)glContext { - return _glContext->eagl_context(); -} -#endif // TARGET_OS_OSX - -- (CVTextureCacheType)textureCache { - @synchronized(self) { - if (!_textureCache) { - _textureCache = _glContext->cv_texture_cache(); - } - } - return _textureCache; -} - -- (mediapipe::GpuBufferMultiPool *)gpuBufferPool { - return _gpuBufferPool; -} - -- (id)mtlDevice { - @synchronized(self) { - if (!_mtlDevice) { - _mtlDevice = MTLCreateSystemDefaultDevice(); - } - } - return _mtlDevice; -} - -- (id)mtlCommandQueue { - @synchronized(self) { - if (!_mtlCommandQueue) { - _mtlCommandQueue = [self.mtlDevice newCommandQueue]; - } - } - return _mtlCommandQueue; -} - -#if COREVIDEO_SUPPORTS_METAL -- (CVMetalTextureCacheRef)mtlTextureCache { - @synchronized(self) { - if (!_mtlTextureCache) { - CVReturn __unused err = - CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); - NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); - // TODO: register and flush metal caches too. 
- } - } - return _mtlTextureCache; -} -#endif - -@end diff --git a/mediapipe/gpu/MPPGraphGPUDataTests.mm b/mediapipe/gpu/MPPGraphGPUDataTests.mm deleted file mode 100644 index e8b50845b..000000000 --- a/mediapipe/gpu/MPPGraphGPUDataTests.mm +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import - -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/port/threadpool.h" - -#import "mediapipe/gpu/MPPGraphGPUData.h" -#import "mediapipe/gpu/gpu_shared_data_internal.h" - -@interface MPPGraphGPUDataTests : XCTestCase { -} -@end - -@implementation MPPGraphGPUDataTests - -// This test verifies that the internal Objective-C object is correctly -// released when the C++ wrapper is released. -- (void)testCorrectlyReleased { - __weak id gpuData = nil; - std::weak_ptr gpuRes; - @autoreleasepool { - mediapipe::GpuSharedData gpu_shared; - gpuRes = gpu_shared.gpu_resources; - gpuData = gpu_shared.gpu_resources->ios_gpu_data(); - XCTAssertNotEqual(gpuRes.lock(), nullptr); - XCTAssertNotNil(gpuData); - } - XCTAssertEqual(gpuRes.lock(), nullptr); - XCTAssertNil(gpuData); -} - -// This test verifies that the lazy initialization of the glContext instance -// variable is thread-safe. All threads should read the same value. -- (void)testGlContextThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - EAGLContext* ogl_context[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &ogl_context, i] { - ogl_context[i] = gpu_shared.gpu_resources->ios_gpu_data().glContext; - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(ogl_context[i], ogl_context[i + 1]); - } -} - -// This test verifies that the lazy initialization of the textureCache instance -// variable is thread-safe. All threads should read the same value. 
-- (void)testTextureCacheThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - CFHolder texture_cache[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &texture_cache, i] { - texture_cache[i].reset(gpu_shared.gpu_resources->ios_gpu_data().textureCache); - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(*texture_cache[i], *texture_cache[i + 1]); - } -} - -@end diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index dc1e27a5c..1acf7cbfb 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,14 +14,15 @@ #import "mediapipe/gpu/MPPMetalHelper.h" -#import "mediapipe/gpu/MPPGraphGPUData.h" +#import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/graph_support.h" +#import "mediapipe/gpu/metal_shared_resources.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" @interface MPPMetalHelper () { - MPPGraphGPUData* _gpuShared; + mediapipe::GpuResources* _gpuResources; } @end @@ -46,7 +47,7 @@ class MetalHelperLegacySupport { - (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources { self = [super init]; if (self) { - _gpuShared = gpuResources->ios_gpu_data(); + _gpuResources = gpuResources; } return self; } @@ -111,19 +112,19 @@ class MetalHelperLegacySupport { } - (id)mtlDevice { - return _gpuShared.mtlDevice; + return _gpuResources->metal_shared().resources().mtlDevice; } - (id)mtlCommandQueue { - return _gpuShared.mtlCommandQueue; + return _gpuResources->metal_shared().resources().mtlCommandQueue; } - (CVMetalTextureCacheRef)mtlTextureCache { - return _gpuShared.mtlTextureCache; + return _gpuResources->metal_shared().resources().mtlTextureCache; } - (id)commandBuffer { - return [_gpuShared.mtlCommandQueue commandBuffer]; + return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; } - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer @@ -175,8 +176,9 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( - NULL, _gpuShared.mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, - metalPixelFormat, width, height, plane, &texture); + NULL, _gpuResources->metal_shared().resources().mtlTextureCache, + mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, + &texture); CHECK_EQ(err, kCVReturnSuccess); return texture; } @@ -197,19 +199,20 @@ class MetalHelperLegacySupport { } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { - return _gpuShared.gpuBufferPool->GetBuffer(width, height); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height format:(mediapipe::GpuBufferFormat)format { - return _gpuShared.gpuBufferPool->GetBuffer(width, height, format); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); } - (id)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { - return [_gpuShared.mtlDevice newLibraryWithFile:[[NSBundle bundleForClass:[self class]] - pathForResource:name ofType:@"metallib"] - error:error]; + return [_gpuResources->metal_shared().resources().mtlDevice + newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name + 
ofType:@"metallib"] + error:error]; } @end diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 91723a7d1..203a8dfd1 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -21,7 +21,7 @@ #include "mediapipe/gpu/graph_support.h" #if __APPLE__ -#import "mediapipe/gpu/MPPGraphGPUData.h" +#include "mediapipe/gpu/metal_shared_resources.h" #endif // __APPLE__ namespace mediapipe { @@ -97,15 +97,14 @@ GpuResources::GpuResources(std::shared_ptr gl_context) #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() - multiPool:&gpu_buffer_pool_]; + metal_shared_ = std::make_unique(); #endif // __APPLE__ } GpuResources::~GpuResources() { #if __APPLE__ - // Note: on Apple platforms, this object contains Objective-C objects. The - // destructor will release them, but ARC must be on. + // Note: on Apple platforms, this object contains Objective-C objects. + // The destructor will release them, but ARC must be on. #if !__has_feature(objc_arc) #error This file must be built with ARC. #endif @@ -196,10 +195,6 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext( GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {} -#if __APPLE__ -MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; } -#endif // __APPLE__ - extern const GraphService kGpuService; #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 4fe6ba04e..3f7c67e2e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -31,15 +31,14 @@ #ifdef __APPLE__ #include "mediapipe/gpu/cv_texture_cache_manager.h" -#ifdef __OBJC__ -@class MPPGraphGPUData; -#else -struct MPPGraphGPUData; -#endif // __OBJC__ #endif // defined(__APPLE__) namespace mediapipe { +#ifdef __APPLE__ +class MetalSharedResources; +#endif // defined(__APPLE__) + // TODO: rename to GpuService or GpuManager or something. class GpuResources { public: @@ -56,9 +55,7 @@ class GpuResources { // Shared GL context for calculators. // TODO: require passing a context or node identifier. - const std::shared_ptr& gl_context() { - return gl_context(nullptr); - }; + const std::shared_ptr& gl_context() { return gl_context(nullptr); } const std::shared_ptr& gl_context(CalculatorContext* cc); @@ -66,7 +63,7 @@ class GpuResources { GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; } #ifdef __APPLE__ - MPPGraphGPUData* ios_gpu_data(); + MetalSharedResources& metal_shared() { return *metal_shared_; } #endif // defined(__APPLE__)§ absl::Status PrepareGpuNode(CalculatorNode* node); @@ -96,8 +93,7 @@ class GpuResources { GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ - // Note that this is an Objective-C object. 
- MPPGraphGPUData* ios_gpu_data_; + std::unique_ptr metal_shared_; #endif // defined(__APPLE__) std::map> named_executors_; diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index d77692164..fafdfee8a 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -83,11 +83,11 @@ objc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/framework/port:threadpool", - "//mediapipe/gpu:MPPGraphGPUData", "//mediapipe/gpu:gl_base", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", + "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", "@com_google_absl//absl/base:core_headers", From bfa57310c4dfb43e9ea3d5b24059b7e042836911 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 10:17:46 -0800 Subject: [PATCH 051/346] Move TextPreprocessing to "processors" folder. PiperOrigin-RevId: 490532670 --- mediapipe/tasks/cc/components/BUILD | 43 ------------------- .../tasks/cc/components/processors/BUILD | 26 +++++++++++ .../cc/components/processors/proto/BUILD | 9 ++++ .../text_preprocessing_graph_options.proto | 2 +- .../text_preprocessing_graph.cc | 22 +++++----- .../text_preprocessing_graph.h | 30 +++++++------ mediapipe/tasks/cc/components/proto/BUILD | 9 ---- mediapipe/tasks/cc/text/text_classifier/BUILD | 4 +- .../text_classifier/text_classifier_graph.cc | 12 +++--- mediapipe/tasks/cc/text/text_embedder/BUILD | 4 +- .../text/text_embedder/text_embedder_graph.cc | 12 +++--- 11 files changed, 80 insertions(+), 93 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/BUILD rename mediapipe/tasks/cc/components/{ => processors}/proto/text_preprocessing_graph_options.proto (96%) rename mediapipe/tasks/cc/components/{ => processors}/text_preprocessing_graph.cc (94%) rename mediapipe/tasks/cc/components/{ => processors}/text_preprocessing_graph.h (67%) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD deleted file mode 100644 index 54a5207d2..000000000 --- a/mediapipe/tasks/cc/components/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -# TODO: Investigate rewriting the build rule to only link -# the Bert Preprocessor if it's needed. 
-cc_library( - name = "text_preprocessing_graph", - srcs = ["text_preprocessing_graph.cc"], - hdrs = ["text_preprocessing_graph.h"], - deps = [ - "//mediapipe/calculators/tensor:bert_preprocessor_calculator", - "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:text_to_tensor_calculator", - "//mediapipe/framework:subgraph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 4946683f5..185bf231b 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -133,3 +133,29 @@ cc_library( ) # TODO: Enable this test + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. +cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 9c58a8585..f48c4bad8 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -59,3 +59,12 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", ], ) + +mediapipe_proto_library( + name = "text_preprocessing_graph_options_proto", + srcs = ["text_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto similarity index 96% rename from mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto rename to mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index 926e3d7fb..a67cfd8a9 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ 
b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc similarity index 94% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.cc rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index 6aad8fdd5..de16375bd 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include @@ -25,13 +25,14 @@ limitations under the License. #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -41,7 +42,8 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; +using ::mediapipe::tasks::components::processors::proto:: + TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; @@ -169,7 +171,7 @@ absl::StatusOr GetMaxSeqLen(const tflite::SubGraph& model_graph) { } } // namespace -absl::Status ConfigureTextPreprocessingSubgraph( +absl::Status ConfigureTextPreprocessingGraph( const ModelResources& model_resources, TextPreprocessingGraphOptions& options) { if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { @@ -200,8 +202,7 @@ absl::Status ConfigureTextPreprocessingSubgraph( return absl::OkStatus(); } -// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text -// preprocessing. +// A TextPreprocessingGraph performs text preprocessing. // - Accepts a std::string input and outputs CPU tensors. // // Inputs: @@ -216,9 +217,9 @@ absl::Status ConfigureTextPreprocessingSubgraph( // Vector containing the preprocessed input tensors for the TFLite model. // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureTextPreprocessing()' function. See header file for more -// details. -class TextPreprocessingSubgraph : public mediapipe::Subgraph { +// using the 'ConfigureTextPreprocessingGraph()' function. See header file for +// more details. 
+class TextPreprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -267,8 +268,9 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::TextPreprocessingSubgraph); + ::mediapipe::tasks::components::processors::TextPreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h similarity index 67% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.h rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h index b031a5550..43d57be29 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h @@ -13,26 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" -// Configures a TextPreprocessing subgraph using the provided `model_resources` +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +// Configures a TextPreprocessingGraph using the provided `model_resources` // and TextPreprocessingGraphOptions. // - Accepts a std::string input and outputs CPU tensors. // // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingSubgraph"); // MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph( // model_resources, // &preprocessing.GetOptions())); // -// The resulting TextPreprocessing subgraph has the following I/O: +// The resulting TextPreprocessingGraph has the following I/O: // Inputs: // TEXT - std::string // The text to preprocess. @@ -43,16 +48,13 @@ limitations under the License. // Outputs: // TENSORS - std::vector // Vector containing the preprocessed input tensors for the TFLite model. 
-namespace mediapipe { -namespace tasks { -namespace components { - -absl::Status ConfigureTextPreprocessingSubgraph( - const tasks::core::ModelResources& model_resources, - tasks::components::proto::TextPreprocessingGraphOptions& options); +absl::Status ConfigureTextPreprocessingGraph( + const core::ModelResources& model_resources, + proto::TextPreprocessingGraphOptions& options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 4534a1652..569023753 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -22,12 +22,3 @@ mediapipe_proto_library( name = "segmenter_options_proto", srcs = ["segmenter_options.proto"], ) - -mediapipe_proto_library( - name = "text_preprocessing_graph_options_proto", - srcs = ["text_preprocessing_graph_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 01adc9fc3..61395cf4e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -52,11 +52,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_calculator", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 9a7dce1aa..3be92f309 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -115,12 +115,12 @@ class TextClassifierGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index 27c9cb730..f19af35be 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -54,11 +54,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index c54636ee2..225ef07bd 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. From 41a7f9d7d6fdc0bfd1c9e7d4cc00532512474de2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 15:23:02 -0800 Subject: [PATCH 052/346] Internal change PiperOrigin-RevId: 490595529 --- mediapipe/web/graph_runner/graph_runner.ts | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index c4654794c..378bc0a4d 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -176,10 +176,14 @@ export class GraphRunner { if (glCanvas !== undefined) { this.wasmModule.canvas = glCanvas; - } else { + } else if (typeof OffscreenCanvas !== 'undefined') { // If no canvas is provided, assume Chrome/Firefox and just make an // OffscreenCanvas for GPU processing. this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } else { + console.warn('OffscreenCanvas not detected and GraphRunner constructor ' + + 'glCanvas parameter is undefined. Creating backup canvas.'); + this.wasmModule.canvas = document.createElement('canvas'); } } From 0bdb48ceb18a772158b92793daf6ac4bf8ce6f76 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 16:17:02 -0800 Subject: [PATCH 053/346] Use kUtilityFramebuffer in GlCalculatorHelper All calculators using the same context can share a single framebuffer object. 
PiperOrigin-RevId: 490605074
---
 mediapipe/gpu/gl_calculator_helper.cc | 18 +++---------------
 1 file changed, 3 insertions(+), 15 deletions(-)

diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc
index 7d317e0f1..9b217ddfd 100644
--- a/mediapipe/gpu/gl_calculator_helper.cc
+++ b/mediapipe/gpu/gl_calculator_helper.cc
@@ -27,19 +27,7 @@ namespace mediapipe {

 GlCalculatorHelper::GlCalculatorHelper() {}

-GlCalculatorHelper::~GlCalculatorHelper() {
-  if (!Initialized()) return;
-  RunInGlContext(
-      [this] {
-        if (framebuffer_) {
-          glDeleteFramebuffers(1, &framebuffer_);
-          framebuffer_ = 0;
-        }
-        return absl::OkStatus();
-      },
-      /*calculator_context=*/nullptr)
-      .IgnoreError();
-}
+GlCalculatorHelper::~GlCalculatorHelper() {}

 void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc,
                                             GpuResources* gpu_resources) {
@@ -125,9 +113,9 @@ void GlCalculatorHelper::CreateFramebuffer() {
   // Our framebuffer will have a color attachment but no depth attachment,
   // so it's important that the depth test be off. It is disabled by default,
   // but we wanted to be explicit.
-  // TODO: move this to glBindFramebuffer?
+  // TODO: move this to glBindFramebuffer? Or just remove.
   glDisable(GL_DEPTH_TEST);
-  glGenFramebuffers(1, &framebuffer_);
+  framebuffer_ = kUtilityFramebuffer.Get(*gl_context_);
 }

 void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) {

From 395d9d8ea21c93bbefb37ad980ad41f66b9a2f9f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Sun, 27 Nov 2022 00:05:08 -0800
Subject: [PATCH 054/346] Instantiate GetDetectionVectorItemCalculator variant
 of GetVectorItemCalculator<>.

PiperOrigin-RevId: 491123314
---
 mediapipe/calculators/core/BUILD                         | 1 +
 mediapipe/calculators/core/get_vector_item_calculator.cc | 5 +++++
 2 files changed, 6 insertions(+)

diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD
index 39837fadb..3b658eb5b 100644
--- a/mediapipe/calculators/core/BUILD
+++ b/mediapipe/calculators/core/BUILD
@@ -1299,6 +1299,7 @@ cc_library(
         "//mediapipe/framework/api2:packet",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc
index 51fb46b98..3306e4ff3 100644
--- a/mediapipe/calculators/core/get_vector_item_calculator.cc
+++ b/mediapipe/calculators/core/get_vector_item_calculator.cc
@@ -15,6 +15,7 @@
 #include "mediapipe/calculators/core/get_vector_item_calculator.h"

 #include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"

 namespace mediapipe {
@@ -32,5 +33,9 @@ using GetClassificationListVectorItemCalculator =
     GetVectorItemCalculator<mediapipe::ClassificationList>;
 REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);

+using GetDetectionVectorItemCalculator =
+    GetVectorItemCalculator<mediapipe::Detection>;
+REGISTER_CALCULATOR(GetDetectionVectorItemCalculator);
+
 }  // namespace api2
 }  // namespace mediapipe

From 153edc59a111c12b940169a272b36772fcd519a1 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 28 Nov 2022 09:52:40 -0800
Subject: [PATCH 055/346] Add support for browsers without SIMD

PiperOrigin-RevId: 491371277
---
 mediapipe/tasks/web/BUILD    | 12 ++
 mediapipe/tasks/web/audio.ts | 5 +--
mediapipe/tasks/web/audio/BUILD | 1 + .../tasks/web/audio/audio_classifier/BUILD | 2 +- .../audio_classifier/audio_classifier.ts | 41 ++---- .../audio/audio_embedder/audio_embedder.ts | 28 ++-- mediapipe/tasks/web/audio/index.ts | 1 + mediapipe/tasks/web/core/BUILD | 9 +- mediapipe/tasks/web/core/fileset_resolver.ts | 130 ++++++++++++++++++ mediapipe/tasks/web/core/task_runner.ts | 45 +++++- ..._loader_options.d.ts => wasm_fileset.d.ts} | 4 +- mediapipe/tasks/web/text.ts | 5 +- mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/index.ts | 1 + .../tasks/web/text/text_classifier/BUILD | 1 - .../text/text_classifier/text_classifier.ts | 39 ++---- mediapipe/tasks/web/text/text_embedder/BUILD | 1 - .../web/text/text_embedder/text_embedder.ts | 42 ++---- mediapipe/tasks/web/vision.ts | 4 +- mediapipe/tasks/web/vision/BUILD | 1 + .../gesture_recognizer/gesture_recognizer.ts | 46 +++---- .../vision/hand_landmarker/hand_landmarker.ts | 46 +++---- .../image_classifier/image_classifier.ts | 41 ++---- .../vision/image_embedder/image_embedder.ts | 40 ++---- mediapipe/tasks/web/vision/index.ts | 1 + .../vision/object_detector/object_detector.ts | 40 ++---- mediapipe/web/graph_runner/graph_runner.ts | 8 +- third_party/wasm_files.bzl | 76 +++++++--- 28 files changed, 410 insertions(+), 261 deletions(-) create mode 100644 mediapipe/tasks/web/core/fileset_resolver.ts rename mediapipe/tasks/web/core/{wasm_loader_options.d.ts => wasm_fileset.d.ts} (88%) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 7e5d02892..20e717433 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -13,10 +13,16 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_files(srcs = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ]) # Audio @@ -57,6 +63,8 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", ":audio_bundle", ], ) @@ -99,6 +107,8 @@ pkg_npm( deps = [ "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", ":text_bundle", ], ) @@ -141,6 +151,8 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 8c522efcc..2f4fb0315 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,11 +14,12 @@ * limitations under the License. */ -import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/index'; +import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl, FilesetResolver as FilesetResolverImpl} from '../../tasks/web/audio/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. 
const AudioClassifier = AudioClassifierImpl; const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; -export {AudioClassifier, AudioEmbedder}; +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index acd7494d7..d08602521 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -10,5 +10,6 @@ mediapipe_ts_library( deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 498b17845..c419d3b98 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/tasks/web/core:task_runner", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 20c745383..e606019f2 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; @@ -50,28 +50,17 @@ export class AudioClassifier extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioClassifierOptions The options for the audio classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - audioClassifierOptions: AudioClassifierOptions): + wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - AudioClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await TaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset); await classifier.setOptions(audioClassifierOptions); return classifier; } @@ -79,31 +68,31 @@ export class AudioClassifier extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return AudioClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
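
Taken together, these signature changes mean callers hand every factory a `WasmFileset` instead of the old `WasmLoaderOptions`. A usage sketch, assuming hypothetical asset and model paths:

    const fileset: WasmFileset = {
      wasmLoaderPath: '/wasm/audio_wasm_internal.js',    // hypothetical path
      wasmBinaryPath: '/wasm/audio_wasm_internal.wasm',  // hypothetical path
    };
    const classifier = await AudioClassifier.createFromModelPath(
        fileset, '/models/audio_classifier.tflite');     // hypothetical model
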
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return AudioClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 9dce02862..c87aceabe 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -24,7 +24,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -52,25 +52,25 @@ export class AudioEmbedder extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioEmbedderOptions The options for the audio embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, audioEmbedderOptions: AudioEmbedderOptions): Promise { // Create a file locator based on the loader options const fileLocator: FileLocator = { locateFile() { // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); + return wasmFileset.wasmBinaryPath.toString(); } }; const embedder = await createMediaPipeLib( - AudioEmbedder, wasmLoaderOptions.wasmLoaderPath, + AudioEmbedder, wasmFileset.wasmLoaderPath, /* assetLoaderScript= */ undefined, /* glCanvas= */ undefined, fileLocator); await embedder.setOptions(audioEmbedderOptions); @@ -80,31 +80,31 @@ export class AudioEmbedder extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. 
*/ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return AudioEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return AudioEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 17a908f30..dbad8c617 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -16,3 +16,4 @@ export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 6eca8bb4a..d709e3409 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -8,7 +8,7 @@ mediapipe_ts_declaration( name = "core", srcs = [ "base_options.d.ts", - "wasm_loader_options.d.ts", + "wasm_fileset.d.ts", ], ) @@ -18,12 +18,19 @@ mediapipe_ts_library( "task_runner.ts", ], deps = [ + ":core", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], ) +mediapipe_ts_library( + name = "fileset_resolver", + srcs = ["fileset_resolver.ts"], + deps = [":core"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts new file mode 100644 index 000000000..7d68dbc16 --- /dev/null +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -0,0 +1,130 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource URL builder + +import {WasmFileset} from './wasm_fileset'; + +let supportsSimd: boolean|undefined; + +/** + * Simple WASM program to test compatibility with the M91 instruction set. 
+ * Compiled from
+ * https://github.com/GoogleChromeLabs/wasm-feature-detect/blob/main/src/detectors/simd/module.wat
+ */
+const WASM_SIMD_CHECK = new Uint8Array([
+  0, 97, 115, 109, 1, 0, 0, 0, 1, 5, 1, 96, 0, 1, 123, 3,
+  2, 1, 0, 10, 10, 1, 8, 0, 65, 0, 253, 15, 253, 98, 11
+]);
+
+async function isSimdSupported(): Promise<boolean> {
+  if (supportsSimd === undefined) {
+    try {
+      await WebAssembly.instantiate(WASM_SIMD_CHECK);
+      supportsSimd = true;
+    } catch {
+      supportsSimd = false;
+    }
+  }
+
+  return supportsSimd;
+}
+
+async function createFileset(
+    taskName: string, basePath: string = '.'): Promise<WasmFileset> {
+  if (await isSimdSupported()) {
+    return {
+      wasmLoaderPath:
+          `/${basePath}/${taskName}_wasm_internal.js`,
+      wasmBinaryPath:
+          `/${basePath}/${taskName}_wasm_internal.wasm`,
+    };
+  } else {
+    return {
+      wasmLoaderPath:
+          `/${basePath}/${taskName}_wasm_nosimd_internal.js`,
+      wasmBinaryPath: `/${basePath}/${
+          taskName}_wasm_nosimd_internal.wasm`,
+    };
+  }
+}
+
+// tslint:disable:class-as-namespace
+
+/**
+ * Resolves the files required for the MediaPipe Task APIs.
+ *
+ * This class verifies whether SIMD is supported in the current environment and
+ * loads the SIMD files only if support is detected. The returned filesets
+ * require that the Wasm files are published without renaming. If this is not
+ * possible, you can invoke the MediaPipe Tasks APIs using a manually created
+ * `WasmFileset`.
+ */
+export class FilesetResolver {
+  /**
+   * Returns whether SIMD is supported in the current environment.
+   *
+   * If your environment requires custom locations for the MediaPipe Wasm files,
+   * you can use `isSimdSupported()` to decide whether to load the SIMD-based
+   * assets.
+   *
+   * @return Whether SIMD support was detected in the current environment.
+   */
+  static isSimdSupported(): Promise<boolean> {
+    return isSimdSupported();
+  }
+
+  /**
+   * Creates a fileset for the MediaPipe Audio tasks.
+   *
+   * @param basePath An optional base path to specify the directory the Wasm
+   *     files should be loaded from. If not specified, the Wasm files are
+   *     loaded from the host's root directory.
+   * @return A `WasmFileset` that can be used to initialize MediaPipe Audio
+   *     tasks.
+   */
+  static forAudioTasks(basePath?: string): Promise<WasmFileset> {
+    return createFileset('audio', basePath);
+  }
+
+  /**
+   * Creates a fileset for the MediaPipe Text tasks.
+   *
+   * @param basePath An optional base path to specify the directory the Wasm
+   *     files should be loaded from. If not specified, the Wasm files are
+   *     loaded from the host's root directory.
+   * @return A `WasmFileset` that can be used to initialize MediaPipe Text
+   *     tasks.
+   */
+  static forTextTasks(basePath?: string): Promise<WasmFileset> {
+    return createFileset('text', basePath);
+  }
+
+  /**
+   * Creates a fileset for the MediaPipe Vision tasks.
+   *
+   * @param basePath An optional base path to specify the directory the Wasm
+   *     files should be loaded from. If not specified, the Wasm files are
+   *     loaded from the host's root directory.
+   * @return A `WasmFileset` that can be used to initialize MediaPipe Vision
+   *     tasks.
+   */
+  static forVisionTasks(basePath?: string): Promise<WasmFileset> {
+    return createFileset('vision', basePath);
+  }
+}
+
+
diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts
index 67aa4e4df..4085be697 100644
--- a/mediapipe/tasks/web/core/task_runner.ts
+++ b/mediapipe/tasks/web/core/task_runner.ts
@@ -14,9 +14,14 @@
  * limitations under the License.
 */

-import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
+import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
-import {GraphRunner, WasmModule} from '../../../web/graph_runner/graph_runner';
+import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
+
+import {WasmFileset} from './wasm_fileset';
+
+// None of the MP Tasks ship bundle assets.
+const NO_ASSETS = undefined;

 // tslint:disable-next-line:enforce-name-casing
 const WasmMediaPipeImageLib =
@@ -26,8 +31,40 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
   private processingErrors: Error[] = [];

-  constructor(wasmModule: WasmModule) {
-    super(wasmModule);
+  /**
+   * Creates a new instance of a Mediapipe Task. Determines if SIMD is
+   * supported and loads the relevant WASM binary.
+   * @return A fully instantiated instance of `T`.
+   */
+  protected static async createInstance<T extends TaskRunner>(
+      type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
+      fileset: WasmFileset): Promise<T> {
+    const fileLocator: FileLocator = {
+      locateFile() {
+        // The only file loaded with this mechanism is the Wasm binary
+        return fileset.wasmBinaryPath.toString();
+      }
+    };
+
+    if (initializeCanvas) {
+      // Fall back to an OffscreenCanvas created by the GraphRunner if
+      // OffscreenCanvas is available
+      const canvas = typeof OffscreenCanvas === 'undefined' ?
+          document.createElement('canvas') :
+          undefined;
+      return createMediaPipeLib(
+          type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
+    } else {
+      return createMediaPipeLib(
+          type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null,
+          fileLocator);
+    }
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);

     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
diff --git a/mediapipe/tasks/web/core/wasm_loader_options.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts
similarity index 88%
rename from mediapipe/tasks/web/core/wasm_loader_options.d.ts
rename to mediapipe/tasks/web/core/wasm_fileset.d.ts
index 74436583d..18227eab9 100644
--- a/mediapipe/tasks/web/core/wasm_loader_options.d.ts
+++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts
@@ -16,8 +16,8 @@

 // Placeholder for internal dependency on trusted resource url

-/** An object containing the locations of all Wasm assets */
-export declare interface WasmLoaderOptions {
+/** An object containing the locations of the Wasm assets */
+export declare interface WasmFileset {
   /** The path to the Wasm loader script. */
   wasmLoaderPath: string;
   /** The path to the Wasm binary. */
diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts
index 8f15075c5..0636714b8 100644
--- a/mediapipe/tasks/web/text.ts
+++ b/mediapipe/tasks/web/text.ts
@@ -14,11 +14,12 @@
  * limitations under the License.
  */

-import {TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';
+import {FilesetResolver as FilesetResolverImpl, TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';

 // Declare the variables locally so that Rollup in OSS includes them explcilty
 // as exports.
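
The `TaskRunner.createInstance` helper added above now owns file location and canvas setup, so each task's factory shrinks to a single call. A hedged sketch of the calling pattern; `FooTask` is a hypothetical subclass, not one of the shipped tasks:

    class FooTask extends TaskRunner {
      static create(fileset: WasmFileset): Promise<FooTask> {
        // CPU-only tasks (text/audio) pass initializeCanvas=false; the
        // vision tasks pass true so a GL canvas is set up.
        return TaskRunner.createInstance(
            FooTask, /* initializeCanvas= */ false, fileset);
      }
    }
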
+const FilesetResolver = FilesetResolverImpl; const TextClassifier = TextClassifierImpl; const TextEmbedder = TextEmbedderImpl; -export {TextClassifier, TextEmbedder}; +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 4b465b0f5..159db1a0d 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", "//mediapipe/tasks/web/text/text_embedder", ], diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index d50db209c..a28e4dd1c 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -16,3 +16,4 @@ export * from '../../../tasks/web/text/text_classifier/text_classifier'; export * from '../../../tasks/web/text/text_embedder/text_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 71ef02c92..f3d272daa 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -26,7 +26,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 04789f5e1..197869a36 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -22,8 +22,7 @@ import {convertBaseOptionsToProto} from '../../../../tasks/web/components/proces import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; @@ -48,27 +47,17 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textClassifierOptions The options for the text classifier. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - TextClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset); await classifier.setOptions(textClassifierOptions); return classifier; } @@ -76,31 +65,31 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return TextClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
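
Combined with the resolver, the whole text-task setup becomes two awaits. A sketch; the 'wasm' base path and model path are placeholders:

    // Per createFileset above, basePath 'wasm' resolves to
    // '/wasm/text_wasm_internal.js' or its nosimd variant.
    const textFiles = await FilesetResolver.forTextTasks('wasm');
    const classifier = await TextClassifier.createFromModelPath(
        textFiles, '/models/text_classifier.tflite');
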
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return TextClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } /** diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 3f92b8ae1..b858f6b83 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -26,7 +26,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 2042a0985..511fd2411 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -24,8 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; @@ -52,27 +51,17 @@ export class TextEmbedder extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textEmbedderOptions The options for the text embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - TextEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const embedder = await TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset); await embedder.setOptions(textEmbedderOptions); return embedder; } @@ -80,31 +69,31 @@ export class TextEmbedder extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. 
+ * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return TextEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return TextEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } /** @@ -122,14 +111,11 @@ export class TextEmbedder extends TaskRunner { options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } - this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); } - /** * Performs embeding extraction on the provided text and waits synchronously * for the response. diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 74a056464..f1ced59af 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,10 +14,11 @@ * limitations under the License. */ -import {GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; +import {FilesetResolver as FilesetResolverImpl, GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. 
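
The locally declared aliases that follow are what the Rollup bundles re-export, so downstream consumers can pull the tasks and the resolver from a single entry point. A consumer-side sketch; the npm package name is an assumption, not defined anywhere in this patch:

    import {FilesetResolver, GestureRecognizer} from '@mediapipe/tasks-vision';
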
+const FilesetResolver = FilesetResolverImpl; const GestureRecognizer = GestureRecognizerImpl; const HandLandmarker = HandLandmarkerImpl; const ImageClassifier = ImageClassifierImpl; @@ -25,6 +26,7 @@ const ImageEmbedder = ImageEmbedderImpl; const ObjectDetector = ObjectDetectorImpl; export { + FilesetResolver, GestureRecognizer, HandLandmarker, ImageClassifier, diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3c45fbfa6..42bc0a494 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index dd050d0f1..7441911c1 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -29,9 +29,9 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; @@ -82,28 +82,18 @@ export class GestureRecognizer extends /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param gestureRecognizerOptions The options for the gesture recognizer. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const recognizer = await createMediaPipeLib( - GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const recognizer = await VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); await recognizer.setOptions(gestureRecognizerOptions); return recognizer; } @@ -111,35 +101,37 @@ export class GestureRecognizer extends /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return GestureRecognizer.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
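
For the vision tasks, the canvas-initializing path of `createInstance` is exercised end to end. A sketch with placeholder paths:

    const visionFiles = await FilesetResolver.forVisionTasks('wasm');
    const recognizer = await GestureRecognizer.createFromModelPath(
        visionFiles, '/models/gesture_recognizer.task');
    // initializeCanvas=true means createInstance builds an OffscreenCanvas
    // where available, otherwise a DOM canvas.
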
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return GestureRecognizer.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 32b1eed4b..6d69d568c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -25,9 +25,9 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; @@ -71,27 +71,17 @@ export class HandLandmarker extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new `HandLandmarker` from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param handLandmarkerOptions The options for the HandLandmarker. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const landmarker = await createMediaPipeLib( - HandLandmarker, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const landmarker = await VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset); await landmarker.setOptions(handLandmarkerOptions); return landmarker; } @@ -99,35 +89,37 @@ export class HandLandmarker extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the provided model asset buffer. 
- * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return HandLandmarker.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return HandLandmarker.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new HandLandmarkerGraphOptions(); this.handLandmarksDetectorGraphOptions = diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index b59cb6fb1..604795f9f 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -21,9 +21,9 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -49,28 +49,17 @@ export class ImageClassifier extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location + * Wasm binary and its loader. * @param imageClassifierOptions The options for the image classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - imageClassifierOptions: ImageClassifierOptions): + wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - ImageClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset); await classifier.setOptions(imageClassifierOptions); return classifier; } @@ -78,31 +67,31 @@ export class ImageClassifier extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return ImageClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ImageClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index f96f1e961..68068db6d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -23,9 +23,9 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; @@ -51,27 +51,17 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param imageEmbedderOptions The options for the image embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, imageEmbedderOptions: ImageEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - ImageEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const embedder = await VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); await embedder.setOptions(imageEmbedderOptions); return embedder; } @@ -79,31 +69,31 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. 
   */
  static createFromModelBuffer(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
      modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> {
    return ImageEmbedder.createFromOptions(
-        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+        wasmFileset, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new image embedder based on the
   * path to the model asset.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
   * @param modelAssetPath The path to the TFLite model.
   */
  static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
      modelAssetPath: string): Promise<ImageEmbedder> {
    const response = await fetch(modelAssetPath.toString());
    const graphData = await response.arrayBuffer();
    return ImageEmbedder.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+        wasmFileset, new Uint8Array(graphData));
  }

  protected override get baseOptions(): BaseOptionsProto|undefined {
diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts
index d68c00cc7..0337a0f2f 100644
--- a/mediapipe/tasks/web/vision/index.ts
+++ b/mediapipe/tasks/web/vision/index.ts
@@ -19,3 +19,4 @@ export * from '../../../tasks/web/vision/image_embedder/image_embedder';
 export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
 export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
 export * from '../../../tasks/web/vision/object_detector/object_detector';
+export * from '../../../tasks/web/core/fileset_resolver';
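[Editor's note, not part of the patch: re-exporting `fileset_resolver` from the vision entry point lets consumers resolve the Wasm files and build a task from a single import. A rough sketch, assuming the published package name and a `forVisionTasks()` helper on the resolver, neither of which is shown in this diff.]

```typescript
import {FilesetResolver, ObjectDetector} from '@mediapipe/tasks-vision';

// forVisionTasks() is assumed to locate the vision .js/.wasm pair served
// under the given base path and return a WasmFileset.
const vision = await FilesetResolver.forVisionTasks('/wasm');
const detector = await ObjectDetector.createFromModelPath(
    vision, '/models/object_detector.tflite');
```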
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
index 44046cd1e..0f039acb2 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
@@ -19,9 +19,9 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
-import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
-import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner';
+import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

 import {ObjectDetectorOptions} from './object_detector_options';
@@ -48,27 +48,17 @@ export class ObjectDetector extends VisionTaskRunner {
   /**
    * Initializes the Wasm runtime and creates a new object detector from the
    * provided options.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param objectDetectorOptions The options for the Object Detector. Note that
    *     either a path to the model asset or a model buffer needs to be
    *     provided (via `baseOptions`).
    */
   static async createFromOptions(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
-    // Create a file locator based on the loader options
-    const fileLocator: FileLocator = {
-      locateFile() {
-        // The only file we load is the Wasm binary
-        return wasmLoaderOptions.wasmBinaryPath.toString();
-      }
-    };
-
-    const detector = await createMediaPipeLib(
-        ObjectDetector, wasmLoaderOptions.wasmLoaderPath,
-        /* assetLoaderScript= */ undefined,
-        /* glCanvas= */ undefined, fileLocator);
+    const detector = await VisionTaskRunner.createInstance(
+        ObjectDetector, /* initializeCanvas= */ true, wasmFileset);
     await detector.setOptions(objectDetectorOptions);
     return detector;
   }
@@ -76,31 +66,31 @@ export class ObjectDetector extends VisionTaskRunner {
   /**
    * Initializes the Wasm runtime and creates a new object detector based on the
    * provided model asset buffer.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetBuffer A binary representation of the model.
    */
   static createFromModelBuffer(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
     return ObjectDetector.createFromOptions(
-        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+        wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
    * Initializes the Wasm runtime and creates a new object detector based on the
    * path to the model asset.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetPath The path to the model asset.
    */
   static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ObjectDetector> {
     const response = await fetch(modelAssetPath.toString());
     const graphData = await response.arrayBuffer();
     return ObjectDetector.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+        wasmFileset, new Uint8Array(graphData));
   }

   protected override get baseOptions(): BaseOptionsProto|undefined {
diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts
index 378bc0a4d..9a0f7148c 100644
--- a/mediapipe/web/graph_runner/graph_runner.ts
+++ b/mediapipe/web/graph_runner/graph_runner.ts
@@ -133,9 +133,11 @@ export type ImageSource =

 /** A listener that will be invoked with an absl::StatusCode and message. */
 export type ErrorListener = (code: number, message: string) => void;

-// Internal type of constructors used for initializing GraphRunner and
-// subclasses.
-type WasmMediaPipeConstructor<LibType> =
+/**
+ * Internal type of constructors used for initializing GraphRunner and
+ * subclasses.
+ */
+export type WasmMediaPipeConstructor<LibType> =
     (new (
          module: WasmModule,
          canvas?: HTMLCanvasElement|OffscreenCanvas|null) => LibType);
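[Editor's note, not part of the patch: exporting `WasmMediaPipeConstructor` lets code outside `graph_runner.ts`, such as the task runners above, accept a subclass constructor generically. A minimal sketch of the shape this enables; `instantiate` is a hypothetical helper, and the import path mirrors the ones used in the diffs.]

```typescript
import {WasmMediaPipeConstructor, WasmModule} from '../../../../web/graph_runner/graph_runner';

// Matches the exported alias: new (module, canvas?) => T.
function instantiate<T>(
    ctor: WasmMediaPipeConstructor<T>, module: WasmModule,
    canvas?: HTMLCanvasElement|OffscreenCanvas|null): T {
  return new ctor(module, canvas);
}
```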
diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl
index 6bfde21ba..504f8567a 100644
--- a/third_party/wasm_files.bzl
+++ b/third_party/wasm_files.bzl
@@ -12,36 +12,72 @@ def wasm_files():
     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
-        sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"],
-    )
-
-    http_file(
-        name = "com_google_mediapipe_wasm_text_wasm_internal_js",
-        sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"],
-    )
-
-    http_file(
-        name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
-        sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"],
+        sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
-        sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"],
+        sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
+        sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
+        sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_text_wasm_internal_js",
+        sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
-        sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"],
+        sha256 = "84e5f5ac70f7718baeaa09a89b155abbea67386e7d50663301b3af7ef0941e74",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1669173782728605"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
+        sha256 = "36f247673124e32535f217265b96508c1badee8fe2458c11c1efa95b6bec5daa",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1669173785027190"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
+        sha256 =
"cc74d90a8aaf6d006ec24048cc80c33f96baeeb0075a6c6739f30d41da54e450", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1669173787903754"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_internal_js", + sha256 = "c3451423186766b08008e07ef6d52f628fcc0aca75beedd9bb4d87d380f29edd", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1669173790070986"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"], + sha256 = "d1e8ad748913e3f190bfd3f72e0e8a4a308f78b918d54c79cec60a2cf30a49f0", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1669173792993881"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", + sha256 = "e5f1b5e8264ff9a90371653cb0fdbf9ce3b30b712acbd72068af18ebca2293ac", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1669173794969702"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", + sha256 = "24351fe580e88f2065b1978b8b3c0f3ad7b90f1c95805aafa07971ce422b5854", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1669173797596874"], ) From c48ca1f674e2fef6b23a28100fd092ebe656e96a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 13:29:35 -0800 Subject: [PATCH 056/346] internal change PiperOrigin-RevId: 491429214 --- .../tasks/cc/components/containers/BUILD | 5 --- .../tasks/cc/vision/hand_landmarker/BUILD | 6 +++ .../hand_landmarker/hand_landmark.h} | 10 ++--- .../tasks/components/containers/BUILD | 12 ------ .../com/google/mediapipe/tasks/vision/BUILD | 2 + .../handlandmarker}/HandLandmark.java | 2 +- .../python/components/containers/landmark.py | 26 ------------ .../tasks/python/vision/hand_landmarker.py | 26 ++++++++++++ .../web/components/containers/landmark.d.ts | 25 ----------- .../tasks/web/vision/hand_landmarker/BUILD | 1 + .../vision/hand_landmarker/hand_landmark.d.ts | 41 +++++++++++++++++++ 11 files changed, 82 insertions(+), 74 deletions(-) rename mediapipe/tasks/cc/{components/containers/landmark.h => vision/hand_landmarker/hand_landmark.h} (78%) rename mediapipe/tasks/java/com/google/mediapipe/tasks/{components/containers => vision/handlandmarker}/HandLandmark.java (97%) create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index dec977fb8..35d3f4785 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,8 +49,3 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) - -cc_library( - name = "landmark", - hdrs = ["landmark.h"], -) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 46948ee6c..03ec45f7d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -54,6 +54,12 @@ cc_library( ], ) +cc_library( + name = "hand_landmark", + hdrs = ["hand_landmark.h"], + visibility = ["//visibility:public"], +) + cc_library( name = 
"hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h similarity index 78% rename from mediapipe/tasks/cc/components/containers/landmark.h rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h index 6fdd294ae..c8dbc9254 100644 --- a/mediapipe/tasks/cc/components/containers/landmark.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ -namespace mediapipe::tasks::components::containers { +namespace mediapipe::tasks::vision::hand_landmarker { // The 21 hand landmarks. enum HandLandmark { @@ -43,6 +43,6 @@ enum HandLandmark { PINKY_TIP = 20 }; -} // namespace mediapipe::tasks::components::containers +} // namespace mediapipe::tasks::vision::hand_landmarker -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 869157295..d6e6ac740 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -74,18 +74,6 @@ android_library( ], ) -android_library( - name = "handlandmark", - srcs = ["HandLandmark.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "@maven//:androidx_annotation_annotation", - "@maven//:com_google_guava_guava", - ], -) - android_library( name = "landmark", srcs = ["Landmark.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 72cee133f..b7febb118 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -145,6 +145,7 @@ android_library( android_library( name = "handlandmarker", srcs = [ + "handlandmarker/HandLandmark.java", "handlandmarker/HandLandmarker.java", "handlandmarker/HandLandmarkerResult.java", ], @@ -168,6 +169,7 @@ android_library( "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java similarity index 97% rename from mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java rename to mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java index da7c4e0ca..7b21ebddf 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package com.google.mediapipe.tasks.components.containers; +package com.google.mediapipe.tasks.vision.handlandmarker; import androidx.annotation.IntDef; diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index 81b2943dc..dee2a16ad 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -14,7 +14,6 @@ """Landmark data class.""" import dataclasses -import enum from typing import Optional from mediapipe.framework.formats import landmark_pb2 @@ -121,28 +120,3 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) - - -class HandLandmark(enum.IntEnum): - """The 21 hand landmarks.""" - WRIST = 0 - THUMB_CMC = 1 - THUMB_MCP = 2 - THUMB_IP = 3 - THUMB_TIP = 4 - INDEX_FINGER_MCP = 5 - INDEX_FINGER_PIP = 6 - INDEX_FINGER_DIP = 7 - INDEX_FINGER_TIP = 8 - MIDDLE_FINGER_MCP = 9 - MIDDLE_FINGER_PIP = 10 - MIDDLE_FINGER_DIP = 11 - MIDDLE_FINGER_TIP = 12 - RING_FINGER_MCP = 13 - RING_FINGER_PIP = 14 - RING_FINGER_DIP = 15 - RING_FINGER_TIP = 16 - PINKY_MCP = 17 - PINKY_PIP = 18 - PINKY_DIP = 19 - PINKY_TIP = 20 diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index 3367f1da7..a0cd99a83 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -14,6 +14,7 @@ """MediaPipe hand landmarker task.""" import dataclasses +import enum from typing import Callable, Mapping, Optional, List from mediapipe.framework.formats import classification_pb2 @@ -53,6 +54,31 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 + + @dataclasses.dataclass class HandLandmarkerResult: """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image. diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index 352717a2f..c887303d0 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -33,28 +33,3 @@ export declare interface Landmark { /** Whether this landmark is normalized with respect to the image size. */ normalized: boolean; } - -/** The 21 hand landmarks. 
*/ -export const enum HandLandmark { - WRIST = 0, - THUMB_CMC = 1, - THUMB_MCP = 2, - THUMB_IP = 3, - THUMB_TIP = 4, - INDEX_FINGER_MCP = 5, - INDEX_FINGER_PIP = 6, - INDEX_FINGER_DIP = 7, - INDEX_FINGER_TIP = 8, - MIDDLE_FINGER_MCP = 9, - MIDDLE_FINGER_PIP = 10, - MIDDLE_FINGER_DIP = 11, - MIDDLE_FINGER_TIP = 12, - RING_FINGER_MCP = 13, - RING_FINGER_PIP = 14, - RING_FINGER_DIP = 15, - RING_FINGER_TIP = 16, - PINKY_MCP = 17, - PINKY_PIP = 18, - PINKY_DIP = 19, - PINKY_TIP = 20 -} diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 1849687c5..fc3e6ef1f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -34,6 +34,7 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "hand_landmarker_types", srcs = [ + "hand_landmark.d.ts", "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts new file mode 100644 index 000000000..ca2543f78 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** The 21 hand landmarks. */ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} From 342f95fa2044c4957ea7cb65352268a868e3d680 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 28 Nov 2022 13:51:59 -0800 Subject: [PATCH 057/346] Typo fix PiperOrigin-RevId: 491434987 --- mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h | 2 +- mediapipe/tasks/python/vision/image_segmenter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 43bf5b7e6..511d3b9c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -98,7 +98,7 @@ struct ImageSegmenterOptions { // - list of segmented masks. // - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. // - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `cahnnels`. +// `channels`. 
// - batch is always 1 // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 9ef911f75..62fc8bb7c 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -110,7 +110,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): - list of segmented masks. - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - if `output_type` is CONFIDENCE_MASK, float32 Image list of size - `cahnnels`. + `channels`. - batch is always 1 An example of such model can be found at: From b65c40b302ccf397d6da3c27ab2795335e5c63cd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 14:15:16 -0800 Subject: [PATCH 058/346] Internal change PiperOrigin-RevId: 491441446 --- mediapipe/objc/MPPLayerRenderer.m | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mediapipe/objc/MPPLayerRenderer.m b/mediapipe/objc/MPPLayerRenderer.m index 7c3027fb6..edd2216ee 100644 --- a/mediapipe/objc/MPPLayerRenderer.m +++ b/mediapipe/objc/MPPLayerRenderer.m @@ -54,10 +54,11 @@ glGenRenderbuffers(1, &renderbuffer_); glBindRenderbuffer(GL_RENDERBUFFER, renderbuffer_); glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_RENDERBUFFER, renderbuffer_); - BOOL success = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER fromDrawable:_layer]; + BOOL success __unused = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER + fromDrawable:_layer]; NSAssert(success, @"could not create renderbuffer storage for layer with bounds %@", NSStringFromCGRect(_layer.bounds)); - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + GLenum status __unused = glCheckFramebufferStatus(GL_FRAMEBUFFER); NSAssert(status == GL_FRAMEBUFFER_COMPLETE, @"failed to make complete framebuffer object %x", status); } From 26a7ca5c64cd885978677931a7218d33cd7d1dec Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 15:02:55 -0800 Subject: [PATCH 059/346] fix typo and minor formatting issues PiperOrigin-RevId: 491453662 --- mediapipe/python/solutions/drawing_utils.py | 42 ++++++++++----------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index bebcbe97c..1b8b173f7 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MediaPipe solution drawing utils.""" import math @@ -135,15 +134,14 @@ def draw_landmarks( the image. connections: A list of landmark index tuples that specifies how landmarks to be connected in the drawing. - landmark_drawing_spec: Either a DrawingSpec object or a mapping from - hand landmarks to the DrawingSpecs that specifies the landmarks' drawing - settings such as color, line thickness, and circle radius. - If this argument is explicitly set to None, no landmarks will be drawn. - connection_drawing_spec: Either a DrawingSpec object or a mapping from - hand connections to the DrawingSpecs that specifies the - connections' drawing settings such as color and line thickness. - If this argument is explicitly set to None, no landmark connections will - be drawn. 
+ landmark_drawing_spec: Either a DrawingSpec object or a mapping from hand + landmarks to the DrawingSpecs that specifies the landmarks' drawing + settings such as color, line thickness, and circle radius. If this + argument is explicitly set to None, no landmarks will be drawn. + connection_drawing_spec: Either a DrawingSpec object or a mapping from hand + connections to the DrawingSpecs that specifies the connections' drawing + settings such as color and line thickness. If this argument is explicitly + set to None, no landmark connections will be drawn. Raises: ValueError: If one of the followings: @@ -197,14 +195,13 @@ def draw_landmarks( drawing_spec.color, drawing_spec.thickness) -def draw_axis( - image: np.ndarray, - rotation: np.ndarray, - translation: np.ndarray, - focal_length: Tuple[float, float] = (1.0, 1.0), - principal_point: Tuple[float, float] = (0.0, 0.0), - axis_length: float = 0.1, - axis_drawing_spec: DrawingSpec = DrawingSpec()): +def draw_axis(image: np.ndarray, + rotation: np.ndarray, + translation: np.ndarray, + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + axis_length: float = 0.1, + axis_drawing_spec: DrawingSpec = DrawingSpec()): """Draws the 3D axis on the image. Args: @@ -214,8 +211,8 @@ def draw_axis( focal_length: camera focal length along x and y directions. principal_point: camera principal point in x and y. axis_length: length of the axis in the drawing. - axis_drawing_spec: A DrawingSpec object that specifies the xyz axis - drawing settings such as line thickness. + axis_drawing_spec: A DrawingSpec object that specifies the xyz axis drawing + settings such as line thickness. Raises: ValueError: If one of the followings: @@ -226,7 +223,7 @@ def draw_axis( image_rows, image_cols, _ = image.shape # Create axis points in camera coordinate frame. axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - axis_cam = np.matmul(rotation, axis_length*axis_world.T).T + translation + axis_cam = np.matmul(rotation, axis_length * axis_world.T).T + translation x = axis_cam[..., 0] y = axis_cam[..., 1] z = axis_cam[..., 2] @@ -274,8 +271,9 @@ def plot_landmarks(landmark_list: landmark_pb2.NormalizedLandmarkList, connections' drawing settings such as color and line thickness. elevation: The elevation from which to view the plot. azimuth: the azimuth angle to rotate the plot. + Raises: - ValueError: If any connetions contain invalid landmark index. + ValueError: If any connection contains an invalid landmark index. 
""" if not landmark_list: return From 7b74fd53f592ab115f60180278952eafeeb61634 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 15:46:30 -0800 Subject: [PATCH 060/346] Verify that kernel cache is only used when OpenCL is active PiperOrigin-RevId: 491463306 --- .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++--- mediapipe/calculators/tflite/tflite_inference_calculator.cc | 6 +++--- mediapipe/util/tflite/tflite_gpu_runner.h | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index c2c723402..b226dbbd8 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -258,9 +258,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { if (use_kernel_caching_) { // Save kernel file. - auto kernel_cache = absl::make_unique>( - gpu_runner->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector kernel_cache, + gpu_runner->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index afdc9ed6f..0f7fa933e 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. 
From b65c40b302ccf397d6da3c27ab2795335e5c63cd Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 28 Nov 2022 15:46:30 -0800
Subject: [PATCH 060/346] Verify that kernel cache is only used when OpenCL is active

PiperOrigin-RevId: 491463306
---
 .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++---
 mediapipe/calculators/tflite/tflite_inference_calculator.cc | 6 +++---
 mediapipe/util/tflite/tflite_gpu_runner.h | 4 +++-
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
index c2c723402..b226dbbd8 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
@@ -258,9 +258,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
     tflite::gpu::TFLiteGPURunner* gpu_runner) const {
   if (use_kernel_caching_) {
     // Save kernel file.
-    auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
-        gpu_runner->GetSerializedBinaryCache());
-    std::string cache_str(kernel_cache->begin(), kernel_cache->end());
+    ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache,
+                     gpu_runner->GetSerializedBinaryCache());
+    std::string cache_str(kernel_cache.begin(), kernel_cache.end());
     MP_RETURN_IF_ERROR(
         mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
   }
diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
index afdc9ed6f..0f7fa933e 100644
--- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
@@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() {
 #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID)
   if (use_kernel_caching_) {
     // Save kernel file.
-    auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
-        tflite_gpu_runner_->GetSerializedBinaryCache());
-    std::string cache_str(kernel_cache->begin(), kernel_cache->end());
+    ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache,
+                     tflite_gpu_runner_->GetSerializedBinaryCache());
+    std::string cache_str(kernel_cache.begin(), kernel_cache.end());
     MP_RETURN_IF_ERROR(
         mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
   }
diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h
index dfbc8d659..5eeaa230f 100644
--- a/mediapipe/util/tflite/tflite_gpu_runner.h
+++ b/mediapipe/util/tflite/tflite_gpu_runner.h
@@ -21,6 +21,7 @@

 #include "absl/status/status.h"
 #include "mediapipe/framework/port.h"
+#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/statusor.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
@@ -89,7 +90,8 @@ class TFLiteGPURunner {
     serialized_binary_cache_ = std::move(cache);
   }

-  std::vector<uint8_t> GetSerializedBinaryCache() {
+  absl::StatusOr<std::vector<uint8_t>> GetSerializedBinaryCache() {
+    RET_CHECK(cl_environment_) << "CL environment is not initialized.";
     return cl_environment_->GetSerializedBinaryCache();
   }
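[Editor's note, not part of the patch: since `GetSerializedBinaryCache()` now returns an `absl::StatusOr`, the `RET_CHECK` above turns a missing OpenCL environment into an error status instead of a crash. A sketch of how a caller without `ASSIGN_OR_RETURN` would branch explicitly; the logging choice is illustrative.]

```cpp
absl::StatusOr<std::vector<uint8_t>> cache_or =
    gpu_runner->GetSerializedBinaryCache();
if (!cache_or.ok()) {
  // No CL environment was initialized; treat a missing cache as non-fatal.
  LOG(WARNING) << "Skipping kernel cache: " << cache_or.status();
} else {
  std::string cache_str(cache_or->begin(), cache_or->end());
  // ... persist cache_str as before ...
}
```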
From e987b69f397af3d7bb4976d4e77029dacaae999a Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 28 Nov 2022 16:48:17 -0800
Subject: [PATCH 061/346] Add alternative method to determine unique kernel cache path

PiperOrigin-RevId: 491476293
---
 .../tensor/inference_calculator_gl_advanced.cc | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
index b226dbbd8..8fd55efa7 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
@@ -236,14 +236,21 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
     const mediapipe::InferenceCalculatorOptions& options,
     const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
         gpu_delegate_options) {
-  use_kernel_caching_ = gpu_delegate_options.has_cached_kernel_path();
+  // The kernel cache needs a unique filename based on either model_path or the
+  // model token, to prevent the cache from being overwritten if the graph has
+  // more than one model.
+  use_kernel_caching_ =
+      gpu_delegate_options.has_cached_kernel_path() &&
+      (options.has_model_path() || gpu_delegate_options.has_model_token());
   use_serialized_model_ = gpu_delegate_options.has_serialized_model_dir() &&
                           gpu_delegate_options.has_model_token();

   if (use_kernel_caching_) {
+    std::string basename = options.has_model_path()
+                               ? mediapipe::File::Basename(options.model_path())
+                               : gpu_delegate_options.model_token();
     cached_kernel_filename_ = mediapipe::file::JoinPath(
-        gpu_delegate_options.cached_kernel_path(),
-        mediapipe::File::Basename(options.model_path()) + ".ker");
+        gpu_delegate_options.cached_kernel_path(), basename + ".ker");
   }
   if (use_serialized_model_) {
     serialized_model_path_ =

From fc526374abac9e1080e06470004ab292fe0c162a Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 28 Nov 2022 17:48:37 -0800
Subject: [PATCH 062/346] Use GpuResources in GpuTestBase and update GpuBufferMultiPoolTest

PiperOrigin-RevId: 491486495
---
 mediapipe/gpu/gpu_test_base.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h
index e9fd64725..6ec53603b 100644
--- a/mediapipe/gpu/gpu_test_base.h
+++ b/mediapipe/gpu/gpu_test_base.h
@@ -24,13 +24,14 @@ namespace mediapipe {

 class GpuTestBase : public ::testing::Test {
  protected:
-  GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); }
+  GpuTestBase() { helper_.InitializeForTest(gpu_resources_.get()); }

   void RunInGlContext(std::function<void(void)> gl_func) {
     helper_.RunInGlContext(std::move(gl_func));
   }

   GpuSharedData gpu_shared_;
+  std::shared_ptr<GpuResources> gpu_resources_ = gpu_shared_.gpu_resources;
   GlCalculatorHelper helper_;
 };

From cc11b4522837ce2f3763831fca0447e3b7cef495 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 28 Nov 2022 17:52:35 -0800
Subject: [PATCH 063/346] Remove unneeded GPU_SHARED side packet for GlSurfaceSink

PiperOrigin-RevId: 491487092
---
 mediapipe/gpu/gl_surface_sink_calculator.cc | 1 -
 mediapipe/java/com/google/mediapipe/framework/jni/graph.cc | 2 --
 2 files changed, 3 deletions(-)

diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc
index 31500ed9a..ad867c2be 100644
--- a/mediapipe/gpu/gl_surface_sink_calculator.cc
+++ b/mediapipe/gpu/gl_surface_sink_calculator.cc
@@ -37,7 +37,6 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes };
 //   VIDEO or index 0: GpuBuffers to be rendered.
 // Side inputs:
 //   SURFACE: unique_ptr to an EglSurfaceHolder to draw to.
-//   GPU_SHARED: shared GPU resources.
 //
 // See GlSurfaceSinkCalculatorOptions for options.
 class GlSurfaceSinkCalculator : public Node {
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
index 6a67c01cb..23bd553af 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc
@@ -231,8 +231,6 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) {
       *graph_config(), absl::StrCat("egl_surface_sink_", output_stream_name)));
   sink_node->set_calculator("GlSurfaceSinkCalculator");
   sink_node->add_input_stream(output_stream_name);
-  sink_node->add_input_side_packet(
-      absl::StrCat(kGpuSharedTagName, ":", kGpuSharedSidePacketName));

   const std::string input_side_packet_name =
       mediapipe::tool::GetUnusedSidePacketName(

From c8a413bb4e5da6b977695987809a27b8f044f15a Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Tue, 29 Nov 2022 10:17:21 -0800
Subject: [PATCH 064/346] Open up mediapipe framework's visibility.
PiperOrigin-RevId: 491672877 --- mediapipe/calculators/image/BUILD | 41 +-------- mediapipe/calculators/tensorflow/BUILD | 70 +--------------- mediapipe/calculators/tflite/BUILD | 20 +---- mediapipe/calculators/util/BUILD | 83 ------------------- mediapipe/calculators/video/BUILD | 29 +------ mediapipe/examples/desktop/hello_world/BUILD | 3 +- mediapipe/framework/BUILD | 2 +- mediapipe/framework/formats/BUILD | 28 +------ mediapipe/framework/formats/annotation/BUILD | 4 +- mediapipe/framework/formats/motion/BUILD | 7 +- .../framework/formats/object_detection/BUILD | 4 +- mediapipe/framework/stream_handler/BUILD | 19 +---- .../holistic_landmark/calculators/BUILD | 3 - mediapipe/util/tracking/BUILD | 17 ---- 14 files changed, 11 insertions(+), 319 deletions(-) diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index c78bc5cf7..530dd3d4a 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -16,12 +16,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,7 +30,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "scale_image_calculator_proto", srcs = ["scale_image_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "set_alpha_calculator_proto", srcs = ["set_alpha_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "image_cropping_calculator_proto", srcs = ["image_cropping_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "segmentation_smoothing_calculator_proto", srcs = ["segmentation_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -112,7 +104,6 @@ cc_library( cc_library( name = 
"opencv_encoded_image_to_image_frame_calculator", srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_encoded_image_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -127,7 +118,6 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -142,7 +132,6 @@ cc_library( cc_library( name = "opencv_put_text_calculator", srcs = ["opencv_put_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame_opencv", @@ -156,7 +145,6 @@ cc_library( cc_library( name = "set_alpha_calculator", srcs = ["set_alpha_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -183,7 +171,6 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -212,13 +199,11 @@ cc_library( mediapipe_proto_library( name = "rotation_mode_proto", srcs = ["rotation_mode.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", @@ -243,7 +228,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", @@ -287,7 +271,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", @@ -330,7 +313,6 @@ cc_test( cc_library( name = "luminance_calculator", srcs = ["luminance_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,7 +326,6 @@ cc_library( cc_library( name = "sobel_edges_calculator", srcs = ["sobel_edges_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -358,7 +339,6 @@ cc_library( cc_library( name = "recolor_calculator", srcs = ["recolor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", "//mediapipe/util:color_cc_proto", @@ -385,9 +365,6 @@ cc_library( name = "scale_image_utils", srcs = ["scale_image_utils.cc"], hdrs = ["scale_image_utils.h"], - visibility = [ - "//mediapipe:__subpackages__", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -400,9 +377,6 @@ cc_library( cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":scale_image_utils", "//mediapipe/calculators/image:scale_image_calculator_cc_proto", @@ -429,7 +403,6 @@ cc_library( mediapipe_proto_library( name = "image_clone_calculator_proto", srcs = ["image_clone_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -439,7 +412,6 @@ mediapipe_proto_library( cc_library( name = "image_clone_calculator", srcs = ["image_clone_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_clone_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -459,7 +431,6 @@ cc_library( cc_library( name = "image_properties_calculator", srcs = ["image_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", @@ -524,7 +495,6 @@ cc_test( mediapipe_proto_library( name = "mask_overlay_calculator_proto", srcs = ["mask_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -534,7 +504,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "opencv_encoded_image_to_image_frame_calculator_proto", srcs = ["opencv_encoded_image_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -544,7 +513,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs = ["feature_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -554,7 +522,6 @@ mediapipe_proto_library( cc_library( name = "mask_overlay_calculator", srcs = ["mask_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":mask_overlay_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -570,7 +537,6 @@ cc_library( cc_library( name = "feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -597,7 +563,6 @@ cc_library( cc_library( name = "image_file_properties_calculator", srcs = ["image_file_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_file_properties_cc_proto", @@ -627,7 +592,6 @@ cc_test( cc_library( name = "segmentation_smoothing_calculator", srcs = ["segmentation_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -724,7 +688,6 @@ cc_library( mediapipe_proto_library( name = "warp_affine_calculator_proto", srcs = ["warp_affine_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -736,7 +699,6 @@ cc_library( name = "warp_affine_calculator", srcs = ["warp_affine_calculator.cc"], hdrs = ["warp_affine_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":affine_transformation", ":warp_affine_calculator_cc_proto", @@ -817,7 +779,6 @@ cc_test( cc_library( name = "yuv_to_image_calculator", srcs = ["yuv_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 45f64f4f7..0f8f8706a 100644 --- 
a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "graph_tensors_packet_generator_proto", srcs = ["graph_tensors_packet_generator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework:packet_generator_proto", @@ -32,49 +31,42 @@ proto_library( proto_library( name = "matrix_to_tensor_calculator_options_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "lapped_tensor_buffer_calculator_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "object_detection_tensors_to_detections_calculator_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensorflow_inference_calculator_proto", srcs = ["tensorflow_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_squeeze_dimensions_calculator_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_image_frame_calculator_proto", srcs = ["tensor_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_matrix_calculator_proto", srcs = ["tensor_to_matrix_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:time_series_header_proto", @@ -84,30 +76,24 @@ proto_library( proto_library( name = "tensor_to_vector_float_calculator_options_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_int_calculator_options_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_string_calculator_options_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) mediapipe_proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_proto", "//mediapipe/framework:calculator_proto", @@ -118,14 +104,12 @@ mediapipe_proto_library( proto_library( name = "vector_float_to_tensor_calculator_options_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "vector_string_to_tensor_calculator_options_proto", srcs = 
["vector_string_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -136,7 +120,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":graph_tensors_packet_generator_proto"], ) @@ -147,7 +130,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":image_frame_to_tensor_calculator_proto"], ) @@ -155,7 +137,6 @@ mediapipe_cc_proto_library( name = "matrix_to_tensor_calculator_options_cc_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":matrix_to_tensor_calculator_options_proto"], ) @@ -163,7 +144,6 @@ mediapipe_cc_proto_library( name = "lapped_tensor_buffer_calculator_cc_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":lapped_tensor_buffer_calculator_proto"], ) @@ -171,7 +151,6 @@ mediapipe_cc_proto_library( name = "object_detection_tensors_to_detections_calculator_cc_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":object_detection_tensors_to_detections_calculator_proto"], ) @@ -179,7 +158,6 @@ mediapipe_cc_proto_library( name = "tensorflow_inference_calculator_cc_proto", srcs = ["tensorflow_inference_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensorflow_inference_calculator_proto"], ) @@ -190,7 +168,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_generator_proto"], ) @@ -201,7 +178,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_calculator_proto"], ) @@ -212,7 +188,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_generator_proto"], ) @@ -223,7 +198,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_calculator_proto"], ) @@ -231,7 +205,6 @@ mediapipe_cc_proto_library( name = "tensor_squeeze_dimensions_calculator_cc_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_squeeze_dimensions_calculator_proto"], ) @@ -239,7 +212,6 @@ mediapipe_cc_proto_library( name = "tensor_to_image_frame_calculator_cc_proto", srcs = ["tensor_to_image_frame_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_image_frame_calculator_proto"], ) @@ -250,7 +222,6 @@ 
mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tensor_to_matrix_calculator_proto"], ) @@ -258,7 +229,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_float_calculator_options_cc_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_float_calculator_options_proto"], ) @@ -266,7 +236,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_int_calculator_options_cc_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_int_calculator_options_proto"], ) @@ -274,7 +243,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_string_calculator_options_cc_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_string_calculator_options_proto"], ) @@ -285,7 +253,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":vector_int_to_tensor_calculator_options_proto"], ) @@ -293,7 +260,6 @@ mediapipe_cc_proto_library( name = "vector_float_to_tensor_calculator_options_cc_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_float_to_tensor_calculator_options_proto"], ) @@ -301,14 +267,12 @@ mediapipe_cc_proto_library( name = "vector_string_to_tensor_calculator_options_cc_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_string_to_tensor_calculator_options_proto"], ) cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_tensors_packet_generator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -323,7 +287,6 @@ cc_library( cc_library( name = "image_frame_to_tensor_calculator", srcs = ["image_frame_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_frame_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -344,7 +307,6 @@ cc_library( cc_library( name = "matrix_to_tensor_calculator", srcs = ["matrix_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":matrix_to_tensor_calculator_options_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", @@ -366,7 +328,6 @@ cc_library( cc_library( name = "lapped_tensor_buffer_calculator", srcs = ["lapped_tensor_buffer_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,9 +349,6 @@ cc_library( # Layering check doesn't play nicely with portable proto wrappers. 
"no_layering_check", ], - visibility = [ - "//visibility:public", - ], deps = [ ":object_detection_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,9 +365,6 @@ cc_library( cc_library( name = "pack_media_sequence_calculator", srcs = ["pack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", @@ -432,9 +387,6 @@ cc_library( cc_library( name = "string_to_sequence_example_calculator", srcs = ["string_to_sequence_example_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -449,7 +401,6 @@ cc_library( cc_library( name = "tensorflow_inference_calculator", srcs = ["tensorflow_inference_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", @@ -487,7 +438,6 @@ cc_library( "tensorflow_session.h", ], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:core", @@ -505,7 +455,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_calculator", srcs = ["tensorflow_session_from_frozen_graph_calculator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", @@ -537,7 +486,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_generator", srcs = ["tensorflow_session_from_frozen_graph_generator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_frozen_graph_generator_cc_proto", @@ -572,7 +520,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_calculator_cc_proto", @@ -611,7 +558,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_generator_cc_proto", @@ -637,7 +583,6 @@ cc_library( cc_library( name = "tensor_squeeze_dimensions_calculator", srcs = ["tensor_squeeze_dimensions_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_squeeze_dimensions_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -651,7 +596,6 @@ cc_library( cc_library( name = "tensor_to_image_frame_calculator", srcs = ["tensor_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -666,7 +610,6 @@ cc_library( cc_library( name = "tensor_to_matrix_calculator", srcs = ["tensor_to_matrix_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_matrix_calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", @@ -688,7 +631,6 @@ cc_library( cc_library( name = "tfrecord_reader_calculator", srcs = ["tfrecord_reader_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -704,7 +646,6 @@ 
cc_library( cc_library( name = "tensor_to_vector_float_calculator", srcs = ["tensor_to_vector_float_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -724,7 +665,6 @@ cc_library( cc_library( name = "tensor_to_vector_int_calculator", srcs = ["tensor_to_vector_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_int_calculator_options_cc_proto", "@com_google_absl//absl/base:core_headers", @@ -746,7 +686,6 @@ cc_library( cc_library( name = "tensor_to_vector_string_calculator", srcs = ["tensor_to_vector_string_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -766,9 +705,6 @@ cc_library( cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", @@ -786,7 +722,6 @@ cc_library( cc_library( name = "vector_int_to_tensor_calculator", srcs = ["vector_int_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_int_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -800,7 +735,6 @@ cc_library( cc_library( name = "vector_float_to_tensor_calculator", srcs = ["vector_float_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_float_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -814,7 +748,6 @@ cc_library( cc_library( name = "vector_string_to_tensor_calculator", srcs = ["vector_string_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_string_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -828,7 +761,6 @@ cc_library( cc_library( name = "unpack_yt8m_sequence_example_calculator", srcs = ["unpack_yt8m_sequence_example_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 8edaeee02..db2a27630 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -18,12 +18,11 @@ load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "ssd_anchors_calculator_proto", srcs = ["ssd_anchors_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -33,7 +32,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_custom_op_resolver_calculator_proto", srcs = ["tflite_custom_op_resolver_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -43,7 +41,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_inference_calculator_proto", srcs = ["tflite_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -53,7 +50,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_converter_calculator_proto", srcs = ["tflite_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -63,7 +59,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_segmentation_calculator_proto", srcs = ["tflite_tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -73,7 +68,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_detections_calculator_proto", srcs = ["tflite_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_classification_calculator_proto", srcs = ["tflite_tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_landmarks_calculator_proto", srcs = ["tflite_tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -103,7 +95,6 @@ mediapipe_proto_library( cc_library( name = "ssd_anchors_calculator", srcs = ["ssd_anchors_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -117,7 +108,6 @@ cc_library( cc_library( name = "tflite_custom_op_resolver_calculator", srcs = ["tflite_custom_op_resolver_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_custom_op_resolver_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -208,7 +198,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -287,7 +276,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_converter_calculator_cc_proto", "//mediapipe/util/tflite:config", @@ -326,7 +314,6 @@ cc_library( cc_library( name = "tflite_model_calculator", srcs = ["tflite_model_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -340,7 +327,6 @@ cc_library( cc_library( name = "tflite_tensors_to_segmentation_calculator", srcs = ["tflite_tensors_to_segmentation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -408,7 +394,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -444,7 +429,6 @@ cc_library( cc_library( name = "tflite_tensors_to_classification_calculator", srcs = ["tflite_tensors_to_classification_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
":tflite_tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -476,7 +460,6 @@ cc_library( cc_library( name = "tflite_tensors_to_landmarks_calculator", srcs = ["tflite_tensors_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -490,7 +473,6 @@ cc_library( cc_library( name = "tflite_tensors_to_floats_calculator", srcs = ["tflite_tensors_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 24e976a73..43eadd53b 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "annotation_overlay_calculator_proto", srcs = ["annotation_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -50,7 +48,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detection_label_id_to_text_calculator_proto", srcs = ["detection_label_id_to_text_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -61,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "filter_detections_calculator_proto", srcs = ["filter_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -71,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -81,13 +76,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "latency_proto", srcs = ["latency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "non_max_suppression_calculator_proto", srcs = ["non_max_suppression_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -97,13 +90,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_frequency_proto", srcs = ["packet_frequency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "packet_frequency_calculator_proto", srcs = ["packet_frequency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -113,7 +104,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_latency_calculator_proto", srcs = ["packet_latency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -123,7 +113,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "collection_has_min_size_calculator_proto", srcs = ["collection_has_min_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -133,7 +122,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "association_calculator_proto", srcs = ["association_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -143,7 +131,6 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -188,7 +175,6 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", @@ -228,9 +214,6 @@ cc_test( cc_library( name = "clock_timestamp_calculator", srcs = ["clock_timestamp_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -246,9 +229,6 @@ cc_library( cc_library( name = "clock_latency_calculator", srcs = ["clock_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -263,7 +243,6 @@ cc_library( cc_library( name = "annotation_overlay_calculator", srcs = ["annotation_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -296,7 +275,6 @@ cc_library( cc_library( name = "detection_label_id_to_text_calculator", srcs = ["detection_label_id_to_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -328,7 +306,6 @@ cc_library( cc_library( name = "timed_box_list_id_to_label_calculator", srcs = ["timed_box_list_id_to_label_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_id_to_label_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -357,7 +334,6 @@ cc_library( cc_library( name = "detection_transformation_calculator", srcs = ["detection_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -391,7 +367,6 @@ cc_test( cc_library( name = "non_max_suppression_calculator", srcs = ["non_max_suppression_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":non_max_suppression_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -408,7 +383,6 @@ cc_library( cc_library( name = "thresholding_calculator", srcs = ["thresholding_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":thresholding_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -421,7 +395,6 @@ cc_library( cc_library( name = "detection_to_landmarks_calculator", srcs = ["detection_to_landmarks_calculator.cc"], - 
visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -436,7 +409,6 @@ cc_library( cc_library( name = "filter_detections_calculator", srcs = ["filter_detections_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":filter_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -450,7 +422,6 @@ cc_library( cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_detection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -471,7 +442,6 @@ cc_library( hdrs = [ "detections_to_rects_calculator.h", ], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -489,7 +459,6 @@ cc_library( cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_transformation_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -504,7 +473,6 @@ cc_library( cc_library( name = "rect_projection_calculator", srcs = ["rect_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:rect_cc_proto", @@ -535,7 +503,6 @@ cc_test( mediapipe_proto_library( name = "rect_to_render_data_calculator_proto", srcs = ["rect_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -547,7 +514,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_to_render_scale_calculator_proto", srcs = ["rect_to_render_scale_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -557,7 +523,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -569,7 +534,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_render_data_calculator_proto", srcs = ["landmarks_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -581,7 +545,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_to_render_data_calculator_proto", srcs = ["timed_box_list_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -593,7 +556,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "labels_to_render_data_calculator_proto", srcs = ["labels_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -605,7 +567,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "thresholding_calculator_proto", srcs = ["thresholding_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -617,7 +578,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_rects_calculator_proto", srcs = ["detections_to_rects_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -627,7 +587,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmark_projection_calculator_proto", srcs = ["landmark_projection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -637,7 +596,6 @@ mediapipe_proto_library( cc_library( name = "landmark_visibility_calculator", srcs = ["landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -649,7 +607,6 @@ cc_library( cc_library( name = "set_landmark_visibility_calculator", srcs = ["set_landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -661,7 +618,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_floats_calculator_proto", srcs = ["landmarks_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -671,7 +627,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_transformation_calculator_proto", srcs = ["rect_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -681,7 +636,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_detection_calculator_proto", srcs = ["landmarks_to_detection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -693,7 +647,6 @@ mediapipe_proto_library( cc_library( name = "detections_to_render_data_calculator", srcs = ["detections_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -713,7 +666,6 @@ cc_library( name = "landmarks_to_render_data_calculator", srcs = ["landmarks_to_render_data_calculator.cc"], hdrs = ["landmarks_to_render_data_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -732,7 +684,6 @@ cc_library( cc_library( name = "timed_box_list_to_render_data_calculator", srcs = ["timed_box_list_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -751,7 +702,6 @@ cc_library( cc_library( name = "labels_to_render_data_calculator", srcs = ["labels_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -770,7 +720,6 @@ cc_library( cc_library( name = "rect_to_render_data_calculator", srcs = ["rect_to_render_data_calculator.cc"], - visibility = 
["//visibility:public"], deps = [ ":rect_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -785,7 +734,6 @@ cc_library( cc_library( name = "rect_to_render_scale_calculator", srcs = ["rect_to_render_scale_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_scale_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -820,7 +768,6 @@ cc_test( cc_library( name = "detection_letterbox_removal_calculator", srcs = ["detection_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -834,7 +781,6 @@ cc_library( cc_library( name = "detection_projection_calculator", srcs = ["detection_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -867,7 +813,6 @@ cc_test( cc_library( name = "landmark_letterbox_removal_calculator", srcs = ["landmark_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -881,7 +826,6 @@ cc_library( cc_library( name = "landmark_projection_calculator", srcs = ["landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmark_projection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -914,7 +858,6 @@ cc_test( cc_library( name = "world_landmark_projection_calculator", srcs = ["world_landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -928,7 +871,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -938,7 +880,6 @@ mediapipe_proto_library( cc_library( name = "landmarks_smoothing_calculator", srcs = ["landmarks_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -956,7 +897,6 @@ cc_library( mediapipe_proto_library( name = "visibility_smoothing_calculator_proto", srcs = ["visibility_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -966,7 +906,6 @@ mediapipe_proto_library( cc_library( name = "visibility_smoothing_calculator", srcs = ["visibility_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -982,7 +921,6 @@ cc_library( mediapipe_proto_library( name = "visibility_copy_calculator_proto", srcs = ["visibility_copy_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -992,7 +930,6 @@ mediapipe_proto_library( cc_library( name = "visibility_copy_calculator", srcs = ["visibility_copy_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_copy_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1007,7 +944,6 @@ 
cc_library( cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1054,7 +990,6 @@ cc_test( mediapipe_proto_library( name = "top_k_scores_calculator_proto", srcs = ["top_k_scores_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1064,7 +999,6 @@ mediapipe_proto_library( cc_library( name = "top_k_scores_calculator", srcs = ["top_k_scores_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":top_k_scores_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1108,7 +1042,6 @@ cc_test( mediapipe_proto_library( name = "local_file_contents_calculator_proto", srcs = ["local_file_contents_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1118,7 +1051,6 @@ mediapipe_proto_library( cc_library( name = "local_file_contents_calculator", srcs = ["local_file_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1132,7 +1064,6 @@ cc_library( cc_library( name = "local_file_pattern_contents_calculator", srcs = ["local_file_pattern_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:file_helpers", @@ -1146,7 +1077,6 @@ cc_library( name = "filter_collection_calculator", srcs = ["filter_collection_calculator.cc"], hdrs = ["filter_collection_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:classification_cc_proto", @@ -1164,7 +1094,6 @@ cc_library( name = "collection_has_min_size_calculator", srcs = ["collection_has_min_size_calculator.cc"], hdrs = ["collection_has_min_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":collection_has_min_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1192,7 +1121,6 @@ cc_test( cc_library( name = "association_calculator", hdrs = ["association_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":association_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1209,7 +1137,6 @@ cc_library( cc_library( name = "association_norm_rect_calculator", srcs = ["association_norm_rect_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1224,7 +1151,6 @@ cc_library( cc_library( name = "association_detection_calculator", srcs = ["association_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1259,7 +1185,6 @@ cc_test( cc_library( name = "detections_to_timed_box_list_calculator", srcs = ["detections_to_timed_box_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1274,7 +1199,6 @@ cc_library( cc_library( name = "detection_unique_id_calculator", srcs = ["detection_unique_id_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1287,7 +1211,6 @@ cc_library( mediapipe_proto_library( name = "logic_calculator_proto", srcs = ["logic_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1297,7 +1220,6 @@ mediapipe_proto_library( cc_library( name = "logic_calculator", srcs = ["logic_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":logic_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1310,7 +1232,6 @@ cc_library( cc_library( name = "to_image_calculator", srcs = ["to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -1333,7 +1254,6 @@ cc_library( cc_library( name = "from_image_calculator", srcs = ["from_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -1385,7 +1305,6 @@ cc_test( mediapipe_proto_library( name = "refine_landmarks_from_heatmap_calculator_proto", srcs = ["refine_landmarks_from_heatmap_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1403,7 +1322,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":refine_landmarks_from_heatmap_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1454,7 +1372,6 @@ cc_library( name = "inverse_matrix_calculator", srcs = ["inverse_matrix_calculator.cc"], hdrs = ["inverse_matrix_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 2db3ed252..f2b8135f2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -21,19 +21,17 @@ load( licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -58,7 +56,6 @@ proto_library( proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", @@ -68,7 +65,6 @@ proto_library( proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", @@ -78,7 +74,6 @@ proto_library( proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", 
"//mediapipe/util/tracking:box_detector_proto", @@ -88,7 +83,6 @@ proto_library( proto_library( name = "video_pre_stream_calculator_proto", srcs = ["video_pre_stream_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", ], @@ -101,7 +95,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:motion_analysis_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_calculator_proto"], ) @@ -112,7 +105,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:flow_packager_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_calculator_proto"], ) @@ -123,7 +115,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_tracker_calculator_proto"], ) @@ -134,7 +125,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_calculator_proto"], ) @@ -145,7 +135,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_detector_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_calculator_proto"], ) @@ -155,7 +144,6 @@ mediapipe_cc_proto_library( cc_deps = [ "//mediapipe/framework:calculator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":video_pre_stream_calculator_proto"], ) @@ -163,7 +151,6 @@ mediapipe_cc_proto_library( name = "flow_to_image_calculator_cc_proto", srcs = ["flow_to_image_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":flow_to_image_calculator_proto"], ) @@ -171,14 +158,12 @@ mediapipe_cc_proto_library( name = "opencv_video_encoder_calculator_cc_proto", srcs = ["opencv_video_encoder_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":opencv_video_encoder_calculator_proto"], ) cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_to_image_calculator_cc_proto", "//mediapipe/calculators/video/tool:flow_quantizer_model", @@ -198,7 +183,6 @@ cc_library( cc_library( name = "opencv_video_decoder_calculator", srcs = ["opencv_video_decoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", @@ -217,7 +201,6 @@ cc_library( cc_library( name = "opencv_video_encoder_calculator", srcs = ["opencv_video_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_video_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -240,7 +223,6 @@ cc_library( cc_library( name = "tvl1_optical_flow_calculator", srcs = ["tvl1_optical_flow_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -256,7 +238,6 @@ cc_library( cc_library( name = "motion_analysis_calculator", srcs = ["motion_analysis_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":motion_analysis_calculator_cc_proto", 
"//mediapipe/framework:calculator_framework", @@ -282,7 +263,6 @@ cc_library( cc_library( name = "flow_packager_calculator", srcs = ["flow_packager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_packager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -300,7 +280,6 @@ cc_library( cc_library( name = "box_tracker_calculator", srcs = ["box_tracker_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -327,7 +306,6 @@ cc_library( cc_library( name = "box_detector_calculator", srcs = ["box_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_detector_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -369,7 +347,6 @@ cc_library( cc_library( name = "tracked_detection_manager_calculator", srcs = ["tracked_detection_manager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tracked_detection_manager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -390,7 +367,6 @@ cc_library( cc_library( name = "video_pre_stream_calculator", srcs = ["video_pre_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":video_pre_stream_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,7 +383,6 @@ filegroup( "testdata/format_MKV_VP8_VORBIS.video", "testdata/format_MP4_AVC720P_AAC.video", ], - visibility = ["//visibility:public"], ) cc_test( @@ -480,7 +455,6 @@ mediapipe_binary_graph( name = "parallel_tracker_binarypb", graph = "testdata/parallel_tracker_graph.pbtxt", output_name = "testdata/parallel_tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", @@ -494,7 +468,6 @@ mediapipe_binary_graph( name = "tracker_binarypb", graph = "testdata/tracker_graph.pbtxt", output_name = "testdata/tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD index edf98bf13..27aa088e7 100644 --- a/mediapipe/examples/desktop/hello_world/BUILD +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -14,12 +14,11 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) cc_binary( name = "hello_world", srcs = ["hello_world.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index e3429f1e9..3cc72b4f1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -139,7 +139,7 @@ mediapipe_proto_library( name = "test_calculators_proto", testonly = 1, srcs = ["test_calculators.proto"], - visibility = ["//visibility:public"], + visibility = [":mediapipe_internal"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 4276ffc3a..fdb698c48 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -17,7 +17,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") package( - default_visibility = 
["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) @@ -26,7 +26,6 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats:location_data_proto"], ) @@ -45,7 +44,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "classification_proto", srcs = ["classification.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -64,46 +62,39 @@ mediapipe_register_type( mediapipe_proto_library( name = "image_format_proto", srcs = ["image_format.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "matrix_data_proto", srcs = ["matrix_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "location_data_proto", srcs = ["location_data.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "affine_transform_data_proto", srcs = ["affine_transform_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "time_series_header_proto", srcs = ["time_series_header.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_file_properties_proto", srcs = ["image_file_properties.proto"], - visibility = ["//visibility:public"], ) cc_library( name = "deleting_file", srcs = ["deleting_file.cc"], hdrs = ["deleting_file.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", ], @@ -113,7 +104,6 @@ cc_library( name = "matrix", srcs = ["matrix.cc"], hdrs = ["matrix.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/formats:matrix_data_cc_proto", @@ -129,9 +119,6 @@ cc_library( name = "affine_transform", srcs = ["affine_transform.cc"], hdrs = ["affine_transform.h"], - visibility = [ - "//visibility:public", - ], deps = [ ":affine_transform_data_cc_proto", "//mediapipe/framework:port", @@ -154,7 +141,6 @@ cc_library( name = "image_frame", srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/base", @@ -179,7 +165,6 @@ cc_library( name = "image_frame_opencv", srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "//mediapipe/framework/formats:image_format_cc_proto", @@ -206,7 +191,6 @@ cc_library( name = "location", srcs = ["location.cc"], hdrs = ["location.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", @@ -238,7 +222,6 @@ cc_library( name = "location_opencv", srcs = ["location_opencv.cc"], hdrs = ["location_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":location", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", @@ -261,7 +244,6 @@ cc_test( cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", ], @@ -270,7 +252,6 @@ cc_library( cc_library( name = "yuv_image", hdrs = ["yuv_image.h"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework/port:integral_types", "@libyuv", @@ -294,7 +275,6 @@ cc_test( mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -312,7 +292,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "landmark_proto", srcs = ["landmark.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -344,7 +323,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -374,7 +352,6 @@ cc_library( name = "image_multi_pool", srcs = ["image_multi_pool.cc"], hdrs = ["image_multi_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_frame_pool", @@ -411,7 +388,6 @@ cc_library( hdrs = [ "image_opencv.h", ], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_format_cc_proto", @@ -425,7 +401,6 @@ cc_library( name = "image_frame_pool", srcs = ["image_frame_pool.cc"], hdrs = ["image_frame_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "@com_google_absl//absl/memory", @@ -476,7 +451,6 @@ cc_library( "-landroid", ], }), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 328001e85..9bcb7bccd 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -16,7 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -24,12 +24,10 @@ mediapipe_proto_library( name = "locus_proto", srcs = ["locus.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "rasterization_proto", srcs = ["rasterization.proto"], - visibility = ["//visibility:public"], ) diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 9819d262c..f1bbc0289 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -20,18 +20,16 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "optical_flow_field_data_cc_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], deps = [":optical_flow_field_data_proto"], ) @@ -39,9 +37,6 @@ cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", diff --git a/mediapipe/framework/formats/object_detection/BUILD b/mediapipe/framework/formats/object_detection/BUILD index 39940acdc..35292e1cc 100644 --- 
a/mediapipe/framework/formats/object_detection/BUILD +++ b/mediapipe/framework/formats/object_detection/BUILD @@ -19,17 +19,15 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "anchor_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "anchor_cc_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], deps = [":anchor_proto"], ) diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 866a5120e..01ef6ee86 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -18,35 +18,31 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -54,7 +50,6 @@ mediapipe_cc_proto_library( name = "default_input_stream_handler_cc_proto", srcs = ["default_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":default_input_stream_handler_proto"], ) @@ -62,7 +57,6 @@ mediapipe_cc_proto_library( name = "fixed_size_input_stream_handler_cc_proto", srcs = ["fixed_size_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":fixed_size_input_stream_handler_proto"], ) @@ -70,7 +64,6 @@ mediapipe_cc_proto_library( name = "sync_set_input_stream_handler_cc_proto", srcs = ["sync_set_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":sync_set_input_stream_handler_proto"], ) @@ -78,14 +71,12 @@ mediapipe_cc_proto_library( name = "timestamp_align_input_stream_handler_cc_proto", srcs = ["timestamp_align_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":timestamp_align_input_stream_handler_proto"], ) cc_library( name = "barrier_input_stream_handler", srcs = ["barrier_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -96,7 +87,6 @@ cc_library( name = "default_input_stream_handler", srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", @@ -108,7 +98,6 @@ cc_library( cc_library( name = "early_close_input_stream_handler", srcs = ["early_close_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "@com_google_absl//absl/strings", @@ -119,7 +108,6 @@ cc_library( cc_library( name = "fixed_size_input_stream_handler", srcs = ["fixed_size_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ ":default_input_stream_handler", "//mediapipe/framework:input_stream_handler", @@ -131,7 +119,6 @@ cc_library( cc_library( name = "immediate_input_stream_handler", srcs = ["immediate_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -142,7 +129,6 @@ cc_library( name = "in_order_output_stream_handler", srcs = ["in_order_output_stream_handler.cc"], hdrs = ["in_order_output_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -160,7 +146,6 @@ cc_library( cc_library( name = "mux_input_stream_handler", srcs = ["mux_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "//mediapipe/framework/port:logging", @@ -173,7 +158,6 @@ cc_library( cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -192,7 +176,6 @@ cc_library( cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", diff --git a/mediapipe/modules/holistic_landmark/calculators/BUILD b/mediapipe/modules/holistic_landmark/calculators/BUILD index c3c091924..bc00b697c 100644 --- a/mediapipe/modules/holistic_landmark/calculators/BUILD +++ b/mediapipe/modules/holistic_landmark/calculators/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "hand_detections_from_pose_to_rects_calculator", srcs = ["hand_detections_from_pose_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "roi_tracking_calculator_proto", srcs = ["roi_tracking_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -49,7 +47,6 @@ mediapipe_proto_library( cc_library( name = "roi_tracking_calculator", srcs = ["roi_tracking_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":roi_tracking_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 3f1ebb353..6bca24446 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -134,7 +134,6 @@ proto_library( mediapipe_cc_proto_library( name = "tone_models_cc_proto", srcs = ["tone_models.proto"], - visibility = ["//visibility:public"], deps = [":tone_models_proto"], ) @@ -142,7 
+141,6 @@ mediapipe_cc_proto_library( name = "tone_estimation_cc_proto", srcs = ["tone_estimation.proto"], cc_deps = [":tone_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tone_estimation_proto"], ) @@ -153,21 +151,18 @@ mediapipe_cc_proto_library( ":tone_estimation_cc_proto", ":tone_models_cc_proto", ], - visibility = ["//visibility:public"], deps = [":region_flow_computation_proto"], ) mediapipe_cc_proto_library( name = "motion_saliency_cc_proto", srcs = ["motion_saliency.proto"], - visibility = ["//visibility:public"], deps = [":motion_saliency_proto"], ) mediapipe_cc_proto_library( name = "motion_estimation_cc_proto", srcs = ["motion_estimation.proto"], - visibility = ["//visibility:public"], deps = [":motion_estimation_proto"], ) @@ -179,7 +174,6 @@ mediapipe_cc_proto_library( ":motion_saliency_cc_proto", ":region_flow_computation_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_proto"], ) @@ -187,14 +181,12 @@ mediapipe_cc_proto_library( name = "region_flow_cc_proto", srcs = ["region_flow.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":region_flow_proto"], ) mediapipe_cc_proto_library( name = "motion_models_cc_proto", srcs = ["motion_models.proto"], - visibility = ["//visibility:public"], deps = [":motion_models_proto"], ) @@ -202,21 +194,18 @@ mediapipe_cc_proto_library( name = "camera_motion_cc_proto", srcs = ["camera_motion.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":camera_motion_proto"], ) mediapipe_cc_proto_library( name = "push_pull_filtering_cc_proto", srcs = ["push_pull_filtering.proto"], - visibility = ["//visibility:public"], deps = [":push_pull_filtering_proto"], ) mediapipe_cc_proto_library( name = "frame_selection_solution_evaluator_cc_proto", srcs = ["frame_selection_solution_evaluator.proto"], - visibility = ["//visibility:public"], deps = [":frame_selection_solution_evaluator_proto"], ) @@ -228,7 +217,6 @@ mediapipe_cc_proto_library( ":frame_selection_solution_evaluator_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":frame_selection_proto"], ) @@ -239,7 +227,6 @@ mediapipe_cc_proto_library( ":motion_models_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_proto"], ) @@ -247,7 +234,6 @@ mediapipe_cc_proto_library( name = "tracking_cc_proto", srcs = ["tracking.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tracking_proto"], ) @@ -255,14 +241,12 @@ mediapipe_cc_proto_library( name = "box_tracker_cc_proto", srcs = ["box_tracker.proto"], cc_deps = [":tracking_cc_proto"], - visibility = ["//visibility:public"], deps = [":box_tracker_proto"], ) mediapipe_cc_proto_library( name = "tracked_detection_manager_config_cc_proto", srcs = ["tracked_detection_manager_config.proto"], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_config_proto"], ) @@ -273,7 +257,6 @@ mediapipe_cc_proto_library( ":box_tracker_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_proto"], ) From 09740130e874560957b154bbb51ae4c90dcd64ca Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 29 Nov 2022 11:32:44 -0800 Subject: [PATCH 065/346] Use naturalWidth and naturalHeight for image data PiperOrigin-RevId: 491694147 --- mediapipe/web/graph_runner/graph_runner.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 9a0f7148c..9a8101659 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -325,6 +325,10 @@ export class GraphRunner { if ((imageSource as HTMLVideoElement).videoWidth) { width = (imageSource as HTMLVideoElement).videoWidth; height = (imageSource as HTMLVideoElement).videoHeight; + } else if ((imageSource as HTMLImageElement).naturalWidth) { + // TODO: Ensure this works with SVG images + width = (imageSource as HTMLImageElement).naturalWidth; + height = (imageSource as HTMLImageElement).naturalHeight; } else { width = imageSource.width; height = imageSource.height; From 88173948eed970b3cc5c215ec3541fcc08b1723c Mon Sep 17 00:00:00 2001 From: Michael Hays Date: Tue, 29 Nov 2022 13:37:18 -0800 Subject: [PATCH 066/346] Internal change PiperOrigin-RevId: 491724816 --- mediapipe/web/graph_runner/graph_runner.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 9a8101659..a9bb979af 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -1085,8 +1085,8 @@ async function runScript(scriptUrl: string) { */ export async function createMediaPipeLib( constructorFcn: WasmMediaPipeConstructor, - wasmLoaderScript?: string, - assetLoaderScript?: string, + wasmLoaderScript?: string|null, + assetLoaderScript?: string|null, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, fileLocator?: FileLocator): Promise { const scripts = []; From fcd2d2c5af18dc4ebf16116a4f472b4bdb5e52a0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 29 Nov 2022 14:12:14 -0800 Subject: [PATCH 067/346] Internal change PiperOrigin-RevId: 491733850 --- mediapipe/gpu/BUILD | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9cc670fb6..7a8aa6557 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -176,6 +176,16 @@ cc_library( "-fobjc-arc", # enable reference-counting ], }), + linkopts = select({ + "//conditions:default": [], + "//mediapipe:ios": [ + "-framework OpenGLES", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework AppKit", + ], + }), visibility = ["//visibility:public"], deps = [ ":attachments", @@ -204,8 +214,10 @@ cc_library( }) + select({ "//conditions:default": [ ], - "//mediapipe:ios": [], - "//mediapipe:macos": [], + "//mediapipe:ios": [ + ], + "//mediapipe:macos": [ + ], }), ) From 460aee7933f255c749bda69673174ec91a9be017 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 30 Nov 2022 20:40:00 -0800 Subject: [PATCH 068/346] Make mediapipe_tasks_aar's android_library depend on "//third_party:androidx_annotation". 
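This suggests the AAR's bundled Java sources reference androidx.annotation types directly. A minimal sketch of the kind of reference involved, with hypothetical class and member names (the exact call sites are not shown in this patch):

    import androidx.annotation.NonNull;
    import androidx.annotation.Nullable;

    // Hypothetical illustration only: a field that stays null until a result arrives.
    final class ResultHolder {
      @Nullable private String label;

      void setLabel(@NonNull String label) {
        this.label = label;
      }
    }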
PiperOrigin-RevId: 492092487 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 762184842..6ca67c096 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -289,6 +289,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", ] + select({ From 29c7702984fd0309fbadf64347fdd7cb5604b52f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 05:50:46 -0800 Subject: [PATCH 069/346] Inline formerly nested 'ClassifierOptions' in Java classifier APIs. PiperOrigin-RevId: 492173060 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audioclassifier/AudioClassifier.java | 84 ++++++++++++++--- .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../text/textclassifier/TextClassifier.java | 90 ++++++++++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../imageclassifier/ImageClassifier.java | 82 ++++++++++++++--- .../textclassifier/TextClassifierTest.java | 31 +++++++ .../imageclassifier/ImageClassifierTest.java | 81 +++++++++++------ 8 files changed, 305 insertions(+), 69 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 6771335ad..2afc75ec0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -66,10 +66,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 0f3374175..d78685fe3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import 
com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -266,7 +266,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /* * Sends audio data (a block in a continuous audio stream) to perform audio classification, and - * the results will be available via the {@link ResultListener} provided in the + * the results will be available via the {@link ResultListener} provided in the * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with * the audio stream mode. * @@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + * <p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + * <p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + * <p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /** * Validates and builds the {@link AudioClassifierOptions} instance. * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the audio classifier - * is in the audio stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final AudioClassifierOptions build() { AudioClassifierOptions options = autoBuild(); @@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi { "The audio classifier is in the audio clips mode, a user-defined result listener" + " shouldn't be provided in AudioClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 023a1f286..f9c8e7c76 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -49,10 +49,10 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 341d6bf91..0ea91a9f8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.TaskInfo; @@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); - public abstract TextClassifierOptions build(); + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + * <p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + * <p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + * <p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract TextClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link TextClassifierOptions} instance. + * + * @throws IllegalArgumentException if any of the set options are invalid. + */ + public final TextClassifierOptions build() { + TextClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } + return options; + } } abstract BaseOptions baseOptions(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); public static Builder builder() { - return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + return new AutoValue_TextClassifier_TextClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index b7febb118..2d130ff05 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -98,10 +98,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + 
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 5e278804b..8990f46fd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + * <p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + * <p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + * <p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { /** * Validates and builds the {@link ImageClassifierOptions} instance. * * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the image classifier - * is in the live stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final ImageClassifierOptions build() { ImageClassifierOptions options = autoBuild(); @@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { "The image classifier is in the image or video mode, a user-defined result listener" + " shouldn't be provided in ImageClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index 5e03d2a4c..5ed413f6a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -40,6 +40,37 @@ public class TextClassifierTest { private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 69820ce2d..dac11bf02 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -55,6 +54,37 @@ public class ImageClassifierTest { @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception 
{ + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -105,7 +135,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -125,7 +155,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -141,7 +171,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .setScoreThreshold(0.02f) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -160,10 +190,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) - .build()) + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -183,11 +210,8 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setMaxResults(3) - .setCategoryDenylist(Arrays.asList("bagel")) - .build()) + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -207,7 +231,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -228,7 +252,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = 
ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -251,7 +275,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -322,14 +346,14 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -353,7 +377,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -379,7 +403,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -388,7 +412,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -405,13 +429,14 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.VIDEO) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + ImageClassifierResult results = + imageClassifier.classifyForVideo(image, /* timestampMs= */ i); assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -424,7 +449,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - 
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -436,11 +461,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); + () -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -453,7 +478,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -466,7 +491,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, /*timestampMs=*/ i); + imageClassifier.classifyAsync(image, /* timestampMs= */ i); } } } From 01010fa24887e50f1bb851e9758847f6f340bea3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 07:15:52 -0800 Subject: [PATCH 070/346] Internal change PiperOrigin-RevId: 492188196 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audio/audioembedder/AudioEmbedder.java | 40 ++++++++--- .../tasks/components/processors/BUILD | 13 ---- .../processors/EmbedderOptions.java | 68 ------------------ .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../tasks/text/textembedder/TextEmbedder.java | 41 ++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../vision/imageembedder/ImageEmbedder.java | 40 ++++++++--- .../imageembedder/ImageEmbedderTest.java | 69 +++++++++---------- 9 files changed, 126 insertions(+), 151 deletions(-) delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2afc75ec0..2d29ccf23 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -92,12 +92,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index c0bc04a4e..4bc505d84 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -309,10 +309,24 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score - * threshold, number of results, etc. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
<p>False by default. + */ - public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is + * guaranteed to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + * <p>
False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -354,7 +368,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional> resultListener(); @@ -362,7 +378,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -372,12 +390,14 @@ public final class AudioEmbedder extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index e61e59390..1f99f1612 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -29,19 +29,6 @@ android_library( ], ) -android_library( - name = "embedderoptions", - srcs = ["EmbedderOptions.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) - # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java deleted file mode 100644 index 3cd197234..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.processors; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; - -/** Embedder options shared across MediaPipe Java embedding tasks. */ -@AutoValue -public abstract class EmbedderOptions { - - /** Builder for {@link EmbedderOptions} */ - @AutoValue.Builder - public abstract static class Builder { - /** - * Sets whether L2 normalization should be performed on the returned embeddings. Use this option - * only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. - * In most cases, this is already the case and L2 norm is thus achieved through TF Lite - * inference. - * - *
<p>False by default. - */ - public abstract Builder setL2Normalize(boolean l2Normalize); - - /** - * Sets whether the returned embedding should be quantized to bytes via scalar quantization. - * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed - * to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} if this is - * not the case. - * - * <p>
False by default. - */ - public abstract Builder setQuantize(boolean quantize); - - public abstract EmbedderOptions build(); - } - - public abstract boolean l2Normalize(); - - public abstract boolean quantize(); - - public static Builder builder() { - return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false); - } - - /** - * Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions} - * protobuf message. - */ - public EmbedderOptionsProto.EmbedderOptions convertToProto() { - return EmbedderOptionsProto.EmbedderOptions.newBuilder() - .setL2Normalize(l2Normalize()) - .setQuantize(quantize()) - .build(); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index f9c8e7c76..5b10e9aab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -74,11 +74,11 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 95fa1f087..9b464d0e8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -25,7 +25,7 @@ import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; @@ -41,7 +41,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; /** * Performs embedding extraction on text. @@ -218,20 +217,38 @@ public final class TextEmbedder implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. 
Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + * <p>
False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is + * guaranteed to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + * <p>
False by default. + */ + public abstract Builder setQuantize(boolean quantize); public abstract TextEmbedderOptions build(); } abstract BaseOptions baseOptions(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); public static Builder builder() { - return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder(); + return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder() + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -240,12 +257,14 @@ public final class TextEmbedder implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 2d130ff05..b61c174fe 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -190,11 +190,11 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index 0d8ecd5c3..af053d860 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import 
com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -369,10 +369,24 @@ public final class ImageEmbedder extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
<p>False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is + * guaranteed to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + * <p>
False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -414,7 +428,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional> resultListener(); @@ -422,7 +438,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -432,12 +450,14 @@ public final class ImageEmbedder extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java index 56249ead9..8dec6f80b 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -25,7 +25,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -92,8 +91,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. 
double similarity = ImageEmbedder.cosineSimilarity( @@ -105,12 +104,8 @@ public class ImageEmbedderTest { @Test public void embed_succeedsWithL2Normalization() throws Exception { BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); - EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build(); ImageEmbedderOptions options = - ImageEmbedderOptions.builder() - .setBaseOptions(baseOptions) - .setEmbedderOptions(embedderOptions) - .build(); + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build(); ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -118,8 +113,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -131,12 +126,8 @@ public class ImageEmbedderTest { @Test public void embed_succeedsWithQuantization() throws Exception { BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); - EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build(); ImageEmbedderOptions options = - ImageEmbedderOptions.builder() - .setBaseOptions(baseOptions) - .setEmbedderOptions(embedderOptions) - .build(); + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build(); ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -144,8 +135,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -168,8 +159,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -190,8 +181,8 @@ public class ImageEmbedderTest { imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false); // Check similarity. 
double similarity = ImageEmbedder.cosineSimilarity( @@ -214,8 +205,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -277,12 +268,14 @@ public class ImageEmbedderTest { assertThrows( MediaPipeException.class, () -> - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -303,7 +296,8 @@ public class ImageEmbedderTest { exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -327,7 +321,8 @@ public class ImageEmbedderTest { assertThrows( MediaPipeException.class, () -> - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -340,8 +335,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. 
double similarity = ImageEmbedder.cosineSimilarity( @@ -363,8 +358,8 @@ public class ImageEmbedderTest { for (int i = 0; i < 3; ++i) { ImageEmbedderResult result = - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); } } @@ -378,17 +373,18 @@ public class ImageEmbedderTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageEmbedderResult, inputImage) -> { - assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); assertImageSizeIsExpected(inputImage); }) .build(); try (ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); + () -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -405,14 +401,15 @@ public class ImageEmbedderTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageEmbedderResult, inputImage) -> { - assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); assertImageSizeIsExpected(inputImage); }) .build(); try (ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageEmbedder.embedAsync(image, /*timestampMs=*/ i); + imageEmbedder.embedAsync(image, /* timestampMs= */ i); } } } From a430939fe4b333ddb31a254f6a08b072f7dfff57 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 1 Dec 2022 07:42:55 -0800 Subject: [PATCH 071/346] Document RunningMode PiperOrigin-RevId: 492193299 --- .../vision/gesture_recognizer/gesture_recognizer.ts | 8 ++++++-- .../web/vision/hand_landmarker/hand_landmarker.ts | 8 ++++++-- .../web/vision/image_classifier/image_classifier.ts | 6 ++++-- .../tasks/web/vision/image_embedder/image_embedder.ts | 11 ++++------- .../web/vision/object_detector/object_detector.ts | 8 ++++++-- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 7441911c1..9ec63b07a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -225,7 +225,9 @@ export class GestureRecognizer extends /** * Performs gesture recognition on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `image`. + * * @param image A single image to process. * @return The detected gestures. 
*/ @@ -235,7 +237,9 @@ export class GestureRecognizer extends /** * Performs gesture recognition on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The detected gestures. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 6d69d568c..290f49455 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -177,7 +177,9 @@ export class HandLandmarker extends VisionTaskRunner { /** * Performs hand landmarks detection on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `image`. + * * @param image An image to process. * @return The detected hand landmarks. */ @@ -187,7 +189,9 @@ export class HandLandmarker extends VisionTaskRunner { /** * Performs hand landmarks detection on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The detected hand landmarks. diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 604795f9f..185ddf9ea 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -120,7 +120,8 @@ export class ImageClassifier extends VisionTaskRunner { /** * Performs image classification on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `image`. * * @param image An image to process. * @return The classification result of the image @@ -131,7 +132,8 @@ export class ImageClassifier extends VisionTaskRunner { /** * Performs image classification on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `video`. * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 68068db6d..91352e934 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -122,10 +122,8 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Performs embedding extraction on the provided single image and waits - * synchronously for the response. - * - * Only use this method when the `useStreamMode` option is not set or - * expliclity set to `false`. + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `image`. * * @param image The image to process. 
* @return The classification result of the image @@ -136,9 +134,8 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Performs embedding extraction on the provided video frame and waits - * synchronously for the response. - * - * Only use this method when the `useStreamMode` option is set to `true`. + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `video`. * * @param imageFrame The image frame to process. * @param timestamp The timestamp of the current frame, in ms. diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 0f039acb2..7711c39e9 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -151,7 +151,9 @@ export class ObjectDetector extends VisionTaskRunner { /** * Performs object detection on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ObjectDetector is created with running mode `image`. + * * @param image An image to process. * @return The list of detected objects */ @@ -161,7 +163,9 @@ export class ObjectDetector extends VisionTaskRunner { /** * Performs object detection on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ObjectDetector is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The list of detected objects From e7eee27c1c78649e126d197ec338b779ff72d356 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 08:14:53 -0800 Subject: [PATCH 072/346] Remove the deleted library "mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions" from mediapipe_tasks_aar's android_library deps list. PiperOrigin-RevId: 492200061 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 6ca67c096..d91c03cc2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -286,7 +286,6 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:androidx_annotation", From 3ee37800e2d63092d8f8ded69619380eb55ad9ea Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 08:41:33 -0800 Subject: [PATCH 073/346] Depending on "inference_calculator_cpu" when the mediapipe tasks can only support cpu inference.
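A note on the running-mode contract that [PATCH 071/346] documents above: the same rule holds across the MediaPipe Tasks language bindings, not only the Web API. The sketch below is illustrative only and is not part of any patch in this series; it assumes the MediaPipe Tasks Python API, and the model and image file names are placeholders.

    # Illustrative sketch only, not part of this patch series; assumes the
    # MediaPipe Tasks Python API. File names are placeholders.
    import mediapipe as mp

    BaseOptions = mp.tasks.BaseOptions
    ImageEmbedder = mp.tasks.vision.ImageEmbedder
    ImageEmbedderOptions = mp.tasks.vision.ImageEmbedderOptions
    VisionRunningMode = mp.tasks.vision.RunningMode

    options = ImageEmbedderOptions(
        base_options=BaseOptions(model_asset_path='embedder.tflite'),
        running_mode=VisionRunningMode.VIDEO)

    with ImageEmbedder.create_from_options(options) as embedder:
        frame = mp.Image.create_from_file('frame0.png')
        # embed_for_video() is valid here only because the task was created
        # in VIDEO mode; timestamps must be monotonically increasing, in ms.
        result = embedder.embed_for_video(frame, 0)

Calling embed() on a task created in video mode, or embed_for_video() on one created in image mode, fails at runtime, which is exactly the contract the doc comments above spell out.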
PiperOrigin-RevId: 492205954 --- mediapipe/tasks/cc/audio/audio_classifier/BUILD | 2 +- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 2 +- mediapipe/tasks/cc/text/text_classifier/BUILD | 2 +- mediapipe/tasks/cc/text/text_embedder/BUILD | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index a817bcc3b..f61472413 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -55,7 +55,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index adba28e6a..6a0f627b2 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -56,7 +56,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 61395cf4e..3c9c3fc0e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -47,7 +47,7 @@ cc_library( name = "text_classifier_graph", srcs = ["text_classifier_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index f19af35be..4c970159e 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -48,8 +48,8 @@ cc_library( name = "text_embedder_graph", srcs = ["text_embedder_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", From e685ac93446e22d31a6bc269416ff13dece6edbe Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 1 Dec 2022 08:45:47 -0800 Subject: [PATCH 074/346] Re-use classifier options for ObjectDetector PiperOrigin-RevId: 492206856 --- .../web/components/utils/cosine_similarity.ts | 1 + .../tasks/web/vision/object_detector/BUILD | 1 + .../object_detector_options.d.ts | 33 ++----------------- 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts 
b/mediapipe/tasks/web/components/utils/cosine_similarity.ts index fb1d0c185..1f483b9b6 100644 --- a/mediapipe/tasks/web/components/utils/cosine_similarity.ts +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -36,6 +36,7 @@ export function computeCosineSimilarity(u: Embedding, v: Embedding): number { throw new Error( 'Cannot compute cosine similarity between quantized and float embeddings.'); } + function convertToBytes(data: Uint8Array): number[] { return Array.from(data, v => v - 128); } diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index b6bef6bfa..198585258 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -35,6 +35,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index 1d20ce1e2..7564e7760 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -14,36 +14,9 @@ * limitations under the License. */ +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions extends VisionTaskOptions { - /** - * The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. - */ - displayNamesLocale?: string|undefined; - - /** The maximum number of top-scored detection results to return. */ - maxResults?: number|undefined; - - /** - * Overrides the value provided in the model metadata. Results below this - * value are rejected. - */ - scoreThreshold?: number|undefined; - - /** - * Allowlist of category names. If non-empty, detection results whose category - * name is not in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryDenylist`. - */ - categoryAllowlist?: string[]|undefined; - - /** - * Denylist of category names. If non-empty, detection results whose category - * name is in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryAllowlist`. 
- */ - categoryDenylist?: string[]|undefined; -} +export interface ObjectDetectorOptions extends VisionTaskOptions, + ClassifierOptions {} From 02aa162c9e953b05153f68d13e55a06b34571a0f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 11:09:26 -0800 Subject: [PATCH 075/346] Rename gesture_recognizer test_data to testdata to be consistent with rest of model_maker PiperOrigin-RevId: 492246728 --- .../python/vision/gesture_recognizer/BUILD | 12 ++++++------ .../gesture_recognizer/gesture_recognizer_demo.py | 2 +- .../gesture_recognizer/gesture_recognizer_test.py | 2 +- .../gesture_recognizer/metadata_writer_test.py | 2 +- .../metadata/custom_gesture_classifier.tflite | Bin .../metadata/custom_gesture_classifier_meta.json | 0 .../call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg | Bin .../call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg | Bin .../call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg | Bin .../call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg | Bin .../call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg | Bin .../call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg | Bin .../call/17d804b5-7118-462d-8191-58d764f591b8.jpg | Bin .../call/1d65a858-623a-4984-9420-958c7e870c3e.jpg | Bin .../call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg | Bin .../call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg | Bin .../four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg | Bin .../four/077fa4bf-a99e-496b-b895-709afc614eec.jpg | Bin .../four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg | Bin .../four/07fdea90-1102-4419-a3af-b394cb29531b.jpg | Bin .../four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg | Bin .../four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg | Bin .../four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg | Bin .../four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg | Bin .../four/249c5023-6106-447a-84ac-17eb4713731b.jpg | Bin .../four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg | Bin .../none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg | Bin .../none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg | Bin .../none/00c84257-800d-4032-9e64-e47eb97005f5.jpg | Bin .../none/0a038096-c14f-46ac-9155-980161ebc440.jpg | Bin .../none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg | Bin .../none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg | Bin .../none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg | Bin .../none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg | Bin .../none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg | Bin .../none/0a787971-9377-4888-803f-aef21863ef7d.jpg | Bin .../rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg | Bin .../rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg | Bin .../rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg | Bin .../rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg | Bin .../rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg | Bin .../rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg | Bin .../rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg | Bin .../rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg | Bin .../rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg | Bin .../rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg | Bin 46 files changed, 9 insertions(+), 9 deletions(-) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/metadata/custom_gesture_classifier.tflite (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/metadata/custom_gesture_classifier_meta.json (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => 
testdata}/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => 
testdata}/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg (100%) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index b9425a181..256447a8d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -24,9 +24,9 @@ package( # TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory.
filegroup( - name = "test_data", + name = "testdata", srcs = glob([ - "test_data/**", + "testdata/**", ]), ) @@ -53,7 +53,7 @@ py_test( name = "dataset_test", srcs = ["dataset_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], deps = [ @@ -136,7 +136,7 @@ py_test( size = "large", srcs = ["gesture_recognizer_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, @@ -151,7 +151,7 @@ py_test( name = "metadata_writer_test", srcs = ["metadata_writer_test.py"], data = [ - ":test_data", + ":testdata", ], deps = [ ":metadata_writer", @@ -164,7 +164,7 @@ py_binary( name = "gesture_recognizer_demo", srcs = ["gesture_recognizer_demo.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], python_version = "PY3", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py index 06075fbc6..1cf9f0619 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -31,7 +31,7 @@ FLAGS = flags.FLAGS # TODO: Move hand gesture recognizer demo dataset to an # open-sourced directory. -TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data' +TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data' def define_flags(): diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 9cee88362..280fc6a82 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -25,7 +25,7 @@ from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' tf.keras.backend.experimental.enable_tf_random_generator() diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index e1101e066..83998141d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -23,7 +23,7 @@ from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writ from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata" +_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata" _EXPECTED_JSON = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json")) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite 
b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg diff 
--git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg
From 1e2cb2b35968100e6ec6cd974c2ec01e7bf6be9e Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 1 Dec 2022 11:33:15 -0800 Subject: [PATCH 076/346] Internal change PiperOrigin-RevId: 492253867 --- mediapipe/framework/input_stream_handler.cc | 4 +- .../immediate_input_stream_handler_test.cc | 37 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index d1dffa414..a7bd9ef43 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc
@@ -354,7 +354,9 @@ NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) { } } *min_stream_timestamp = std::min(min_packet, min_bound); - if (*min_stream_timestamp == Timestamp::Done()) { + if (*min_stream_timestamp >= Timestamp::OneOverPostStream()) { + // Either OneOverPostStream or Done indicates no more packets. + *min_stream_timestamp = Timestamp::Done(); last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream(); return NodeReadiness::kReadyForClose; }
diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index e721afb02..e5de7f0c9 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc
@@ -230,6 +230,43 @@ TEST_F(ImmediateInputStreamHandlerTest, StreamDoneReady) { input_stream_handler_->ClearCurrentInputs(cc_); } +// This test checks that the state is ReadyForClose after all streams reach +// Timestamp::Max. +TEST_F(ImmediateInputStreamHandlerTest, ReadyForCloseAfterTimestampMax) { + Timestamp min_stream_timestamp; + std::list<Packet> packets; + + // One packet arrives, ready for process.
+ packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(10))); + input_stream_handler_->AddPackets(name_to_id_["input_a"], packets); + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp(10), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + // No packets arrive, not ready. + EXPECT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Unset(), cc_->InputTimestamp()); + + // Timestamp::Max arrives, ready for close. + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_a"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_b"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_c"], Timestamp::Max().NextAllowedInStream()); + + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Done(), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); +} + // This test checks that when any stream is done, the state is ready to close. TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) { Timestamp min_stream_timestamp;
From 40eb0e63858bd6c8746f4d5127a76ebef1f71cf7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Thu, 1 Dec 2022 12:57:07 -0800 Subject: [PATCH 077/346] Internal change PiperOrigin-RevId: 492276913 --- mediapipe/gpu/multi_pool.h | 2 ++ 1 file changed, 2 insertions(+)
diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h index 8a3cf6be0..e677c3bbf 100644 --- a/mediapipe/gpu/multi_pool.h +++ b/mediapipe/gpu/multi_pool.h
@@ -59,6 +59,8 @@ class MultiPool { MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool, MultiPoolOptions options = kDefaultMultiPoolOptions) : create_simple_pool_(factory), options_(options) {} + explicit MultiPool(MultiPoolOptions options) + : MultiPool(DefaultMakeSimplePool, options) {} // Obtains an item. May either be reused or created anew. Item Get(const Spec& spec);
From fd79f18aeb41d78966a91dbd38107534c3fb29e8 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Thu, 1 Dec 2022 14:13:01 -0800 Subject: [PATCH 078/346] Make BaseOptions pass an absolute path to the C++ layer.
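With this change, a relative `model_asset_path` is expanded to an absolute path before it reaches the C++ layer. A minimal sketch of the resolution step shown in the diff below, assuming only a plain optional path string (the helper name is illustrative, not part of the API):

import os
from typing import Optional

def resolve_model_path(model_asset_path: Optional[str]) -> Optional[str]:
  # A missing path is passed through untouched; only real paths are expanded.
  if model_asset_path is None:
    return None
  # Relative paths are resolved against the current working directory.
  return os.path.abspath(model_asset_path)

# With cwd == "/home/user/app":
# resolve_model_path("models/classifier.tflite")
# returns "/home/user/app/models/classifier.tflite"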
PiperOrigin-RevId: 492296573 --- mediapipe/tasks/python/core/base_options.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 122dc620f..b48fa2ccc 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py
@@ -14,6 +14,7 @@ """Base options for MediaPipe Task APIs.""" import dataclasses +import os from typing import Any, Optional from mediapipe.tasks.cc.core.proto import base_options_pb2
@@ -49,10 +50,14 @@ class BaseOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _BaseOptionsProto: """Generates a BaseOptions protobuf object.""" + if self.model_asset_path is not None: + full_path = os.path.abspath(self.model_asset_path) + else: + full_path = None + return _BaseOptionsProto( model_asset=_ExternalFileProto( - file_name=self.model_asset_path, - file_content=self.model_asset_buffer)) + file_name=full_path, file_content=self.model_asset_buffer)) @classmethod @doc_controls.do_not_generate_docs
From af990c3da1633f164ccf8f75edb0683079b0c005 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 14:58:30 -0800 Subject: [PATCH 079/346] Open up the visibility of "//mediapipe/java/com/google/mediapipe/framework/image:image". PiperOrigin-RevId: 492308109 --- mediapipe/java/com/google/mediapipe/framework/image/BUILD | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index bb3be318d..d9508c1f7 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD
@@ -20,9 +20,7 @@ android_library( name = "image", srcs = glob(["*.java"]), manifest = "AndroidManifest.xml", - visibility = [ - "//mediapipe:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ "//third_party:androidx_legacy_support_v4", "//third_party:autovalue",
From ead41132a856379a9a7d22f29abe471dc11f2b4a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 15:00:00 -0800 Subject: [PATCH 080/346] Load model file content from model file path with the help of GetResourceContents in browsers. This can handle the model files that are provided via a custom ResourceProviderFn. PiperOrigin-RevId: 492308453 --- mediapipe/tasks/cc/core/model_resources.cc | 10 ++++++++++ 1 file changed, 10 insertions(+)
diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 618761f32..d5c12ee95 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc
@@ -99,11 +99,21 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { +#ifdef __EMSCRIPTEN__ + // In browsers, the model file may require a custom ResourceProviderFn to + // provide the model content. The open() method may not work in this case. + // Thus, load the model content from the model file path in advance with + // the help of GetResourceContents. + MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); +#else // If the model file name is a relative path, searches the file in a // platform-specific location and returns the absolute path on success.
ASSIGN_OR_RETURN(std::string path_to_resource, mediapipe::PathToResourceAsFile(model_file_->file_name())); model_file_->set_file_name(path_to_resource); +#endif // __EMSCRIPTEN__ } ASSIGN_OR_RETURN( model_file_handler_,
From 768d2dc548f123246d34fe258d6ab75d05c51d3e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 16:47:05 -0800 Subject: [PATCH 081/346] Separate the web and Java API landmark and world landmark into two classes. This makes the platform interfaces consistent. PiperOrigin-RevId: 492332990 --- .../tasks/components/containers/BUILD | 9 +++ .../tasks/components/containers/Landmark.java | 20 +++--- .../containers/NormalizedLandmark.java | 63 +++++++++++++++++++ .../com/google/mediapipe/tasks/vision/BUILD | 2 + .../GestureRecognizerResult.java | 45 ++++++------- .../handlandmarker/HandLandmarkerResult.java | 52 +++++++-------- .../GestureRecognizerTest.java | 4 +- .../handlandmarker/HandLandmarkerTest.java | 4 +- .../web/components/containers/landmark.d.ts | 28 ++++++--- .../gesture_recognizer/gesture_recognizer.ts | 12 ++-- .../gesture_recognizer_result.d.ts | 4 +- .../vision/hand_landmarker/hand_landmarker.ts | 10 ++- .../hand_landmarker_result.d.ts | 4 +- 13 files changed, 161 insertions(+), 96 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..ad17d5552 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
@@ -83,6 +83,15 @@ android_library( ], ) +android_library( + name = "normalized_landmark", + srcs = ["NormalizedLandmark.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index e45866190..7fb1b99d0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java
@@ -18,16 +18,16 @@ import com.google.auto.value.AutoValue; import java.util.Objects; /** - * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the - * landmark coordinates is normalized respect to the dimension of image, and the coordinates values - * are in the range of [0,1]. Otherwise, it represenet a point in world coordinates. + * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in + * meters. z represents the landmark depth, and the smaller the value the closer the world landmark + * is to the camera. */ @AutoValue public abstract class Landmark { private static final float TOLERANCE = 1e-6f; - public static Landmark create(float x, float y, float z, boolean normalized) { - return new AutoValue_Landmark(x, y, z, normalized); + public static Landmark create(float x, float y, float z) { + return new AutoValue_Landmark(x, y, z); } // The x coordinates of the landmark. @@ -39,28 +39,24 @@ public abstract class Landmark { // The z coordinates of the landmark.
public abstract float z(); - // Whether this landmark is normalized with respect to the image size. - public abstract boolean normalized(); - @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { return false; } Landmark other = (Landmark) o; - return other.normalized() == this.normalized() - && Math.abs(other.x() - this.x()) < TOLERANCE + return Math.abs(other.x() - this.x()) < TOLERANCE && Math.abs(other.y() - this.y()) < TOLERANCE && Math.abs(other.z() - this.z()) < TOLERANCE; } @Override public final int hashCode() { - return Objects.hash(x(), y(), z(), normalized()); + return Objects.hash(x(), y(), z()); } @Override public final String toString() { - return ""; + return ""; } }
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java new file mode 100644 index 000000000..e77f3c3d4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java
@@ -0,0 +1,63 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are + * normalized to [0.0, 1.0] by the image width and height respectively. z represents the landmark + * depth, and the smaller the value the closer the landmark is to the camera. The magnitude of z + * uses roughly the same scale as x. + */ +@AutoValue +public abstract class NormalizedLandmark { + private static final float TOLERANCE = 1e-6f; + + public static NormalizedLandmark create(float x, float y, float z) { + return new AutoValue_NormalizedLandmark(x, y, z); + } + + // The x coordinates of the normalized landmark. + public abstract float x(); + + // The y coordinates of the normalized landmark. + // The z coordinates of the normalized landmark.
+ public abstract float z(); + + @Override + public final boolean equals(Object o) { + if (!(o instanceof NormalizedLandmark)) { + return false; + } + NormalizedLandmark other = (NormalizedLandmark) o; + return Math.abs(other.x() - this.x()) < TOLERANCE + && Math.abs(other.y() - this.y()) < TOLERANCE + && Math.abs(other.z() - this.z()) < TOLERANCE; + } + + @Override + public final int hashCode() { + return Objects.hash(x(), y(), z()); + } + + @Override + public final String toString() { + return ""; + } +}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index b61c174fe..6161fe032 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -135,6 +135,7 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", @@ -167,6 +168,7 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:androidx_annotation_annotation",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java index ef76bf226..90b92175d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java
@@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -43,41 +42,36 @@ public abstract class GestureRecognizerResult implements TaskResult { * @param gesturesProto a List of {@link ClassificationList} */ static GestureRecognizerResult
create( - List<NormalizedLandmarkList> landmarksProto, - List<LandmarkList> worldLandmarksProto, + List<LandmarkProto.NormalizedLandmarkList> landmarksProto, + List<LandmarkProto.LandmarkList> worldLandmarksProto, List<ClassificationList> handednessesProto, List<ClassificationList> gesturesProto, long timestampMs) { - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks = - new ArrayList<>(); - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks = - new ArrayList<>(); + List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>(); + List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>(); List<List<Category>> multiHandHandednesses = new ArrayList<>(); List<List<Category>> multiHandGestures = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List<NormalizedLandmark> handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + com.google.mediapipe.tasks.components.containers.NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) {
@@ -118,11 +112,10 @@ public abstract class GestureRecognizerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks(); + public abstract List<List<NormalizedLandmark>> landmarks(); /** Hand landmarks in world coordinates of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> - worldLandmarks(); + public abstract List<List<Landmark>> worldLandmarks(); /** Handedness of detected hands.
*/ public abstract List<List<Category>> handednesses();
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 2889b0e0b..9092c0a2d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java
@@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.handlandmarker; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -32,47 +31,41 @@ import java.util.List; /** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */ @AutoValue public abstract class HandLandmarkerResult implements TaskResult { /** - * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and - * handedness protobuf messages. + * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness + * protobuf messages.
* * @param landmarksProto a List of {@link NormalizedLandmarkList} * @param worldLandmarksProto a List of {@link LandmarkList} * @param handednessesProto a List of {@link ClassificationList} */ static HandLandmarkerResult create( - List<NormalizedLandmarkList> landmarksProto, - List<LandmarkList> worldLandmarksProto, + List<LandmarkProto.NormalizedLandmarkList> landmarksProto, + List<LandmarkProto.LandmarkList> worldLandmarksProto, List<ClassificationList> handednessesProto, long timestampMs) { - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks = - new ArrayList<>(); - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks = - new ArrayList<>(); + List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>(); + List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>(); List<List<Category>> multiHandHandednesses = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List<NormalizedLandmark> handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { - List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks = - new ArrayList<>(); + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + List<Landmark> handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) {
@@ -98,11 +91,10 @@ public abstract class HandLandmarkerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks(); + public abstract List<List<NormalizedLandmark>> landmarks(); /** Hand landmarks in world coordinates of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> - worldLandmarks(); + public abstract List<List<Landmark>> worldLandmarks(); /** Handedness of detected hands.
*/ public abstract List<List<Category>> handednesses();
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index c0be4cffe..5821b36cc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
@@ -28,7 +28,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -603,7 +603,7 @@ public class GestureRecognizerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate<Landmark, Landmark>) + (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x())
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java index 9e12d210f..c313d385d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java
@@ -27,7 +27,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -399,7 +399,7 @@ public class HandLandmarkerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate<Landmark, Landmark>) + (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x())
diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index c887303d0..0f916bf88 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts
@@ -15,10 +15,27 @@ */ /** - * Landmark represents a point in 3D space with x, y, z coordinates. If - * normalized is true, the landmark coordinates is normalized respect to the - * dimension of image, and the coordinates values are in the range of [0,1]. - * Otherwise, it represenet a point in world coordinates.
+ * Normalized Landmark represents a point in 3D space with x, y, z coordinates. + * x and y are normalized to [0.0, 1.0] by the image width and height + * respectively. z represents the landmark depth, and the smaller the value the + * closer the landmark is to the camera. The magnitude of z uses roughly the + * same scale as x. + */ +export declare interface NormalizedLandmark { + /** The x coordinates of the normalized landmark. */ + x: number; + + /** The y coordinates of the normalized landmark. */ + y: number; + + /** The z coordinates of the normalized landmark. */ + z: number; +} + +/** + * Landmark represents a point in 3D space with x, y, z coordinates. The + * landmark coordinates are in meters. z represents the landmark depth, + * and the smaller the value the closer the world landmark is to the camera. */ export declare interface Landmark { /** The x coordinates of the landmark. */ @@ -29,7 +46,4 @@ export declare interface Landmark { /** The z coordinates of the landmark. */ z: number; - - /** Whether this landmark is normalized with respect to the image size. */ - normalized: boolean; } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 9ec63b07a..15b6acb1a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -27,7 +27,7 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -67,7 +67,7 @@ FULL_IMAGE_RECT.setHeight(1); export class GestureRecognizer extends VisionTaskRunner { private gestures: Category[][] = []; - private landmarks: Landmark[][] = []; + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -306,13 +306,12 @@ export class GestureRecognizer extends for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, - z: handLandmarkProto.getZ() ?? 0, - normalized: true + z: handLandmarkProto.getZ() ?? 0 }); } this.landmarks.push(landmarks); @@ -333,8 +332,7 @@ export class GestureRecognizer extends worldLandmarks.push({ x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, - z: handWorldLandmarkProto.getZ() ?? 0, - normalized: false + z: handWorldLandmarkProto.getZ() ?? 
0 }); } this.worldLandmarks.push(worldLandmarks);
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts index 7c295c9e9..e570270b2 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
@@ -15,14 +15,14 @@ */ import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; /** * Represents the gesture recognition results generated by `GestureRecognizer`. */ export declare interface GestureRecognizerResult { /** Hand landmarks of detected hands. */ - landmarks: Landmark[][]; + landmarks: NormalizedLandmark[][]; /** Hand landmarks in world coordinates of detected hands. */ worldLandmarks: Landmark[][];
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 290f49455..c657275bf 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
@@ -24,7 +24,7 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; @@ -59,7 +59,7 @@ FULL_IMAGE_RECT.setHeight(1); /** Performs hand landmarks detection on images. */ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { - private landmarks: Landmark[][] = []; + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -255,13 +255,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, z: handLandmarkProto.getZ() ?? 0, - normalized: true }); } this.landmarks.push(landmarks); } } /** - * Converts raw data into a landmark, and adds it to our worldLandmarks + * Converts raw data into a world landmark, and adds it to our worldLandmarks * list. */ private adddJsWorldLandmarks(data: Uint8Array[]): void { @@ -283,7 +282,6 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, z: handWorldLandmarkProto.getZ() ??
0, - normalized: false }); } this.worldLandmarks.push(worldLandmarks);
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 044bdfbe7..89f867d69 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
@@ -15,14 +15,14 @@ */ import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; /** * Represents the hand landmarks detection results generated by `HandLandmarker`. */ export declare interface HandLandmarkerResult { /** Hand landmarks of detected hands. */ - landmarks: Landmark[][]; + landmarks: NormalizedLandmark[][]; /** Hand landmarks in world coordinates of detected hands. */ worldLandmarks: Landmark[][];
From dabc2af15baad67d92ac5e9d1b2b2a588167664f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 2 Dec 2022 12:04:06 -0800 Subject: [PATCH 082/346] Fix base path loading in Fileset resolver PiperOrigin-RevId: 492526041 --- mediapipe/tasks/web/core/fileset_resolver.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts index 7d68dbc16..d4691243b 100644 --- a/mediapipe/tasks/web/core/fileset_resolver.ts +++ b/mediapipe/tasks/web/core/fileset_resolver.ts
@@ -48,16 +48,16 @@ async function createFileset( if (await isSimdSupported()) { return { wasmLoaderPath: - `/${basePath}/${taskName}_wasm_internal.js`, + `${basePath}/${taskName}_wasm_internal.js`, wasmBinaryPath: - `/${basePath}/${taskName}_wasm_internal.wasm`, + `${basePath}/${taskName}_wasm_internal.wasm`, }; } else { return { wasmLoaderPath: - `/${basePath}/${taskName}_wasm_nosimd_internal.js`, - wasmBinaryPath: `/${basePath}/${ - taskName}_wasm_nosimd_internal.wasm`, + `${basePath}/${taskName}_wasm_nosimd_internal.js`, + wasmBinaryPath: + `${basePath}/${taskName}_wasm_nosimd_internal.wasm`, }; } }
From da9587033d118eb58672f25c8f2e541ba7037209 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 2 Dec 2022 12:40:59 -0800 Subject: [PATCH 083/346] Move shared code to TaskRunner PiperOrigin-RevId: 492534879 --- .../tasks/web/audio/audio_classifier/BUILD | 3 +- .../audio_classifier/audio_classifier.ts | 38 ++++++++------ .../audio_classifier_options.d.ts | 4 +- .../tasks/web/audio/audio_embedder/BUILD | 1 - .../audio/audio_embedder/audio_embedder.ts | 48 ++++++++--------- .../audio_embedder_options.d.ts | 4 +- mediapipe/tasks/web/audio/core/BUILD | 13 +---- .../web/audio/core/audio_task_options.d.ts | 23 --------- .../tasks/web/audio/core/audio_task_runner.ts | 17 +------ .../tasks/web/components/processors/BUILD | 1 - .../web/components/processors/base_options.ts | 2 +- mediapipe/tasks/web/core/BUILD | 8 +-- .../tasks/web/core/classifier_options.d.ts | 2 - .../tasks/web/core/embedder_options.d.ts | 2 - mediapipe/tasks/web/core/task_runner.ts | 43 ++++++++++------ ..._options.d.ts => task_runner_options.d.ts} | 8 ++- mediapipe/tasks/web/text/core/BUILD | 11 ---- .../web/text/core/text_task_options.d.ts | 23 --------- .../tasks/web/text/text_classifier/BUILD | 5 +- .../text/text_classifier/text_classifier.ts | 51 +++++++++++-------- .../text_classifier_options.d.ts | 4 +-
mediapipe/tasks/web/text/text_embedder/BUILD | 4 +- .../web/text/text_embedder/text_embedder.ts | 51 +++++++++++-------- .../text_embedder/text_embedder_options.d.ts | 4 +- mediapipe/tasks/web/vision/core/BUILD | 2 - .../web/vision/core/vision_task_options.d.ts | 8 +-- .../web/vision/core/vision_task_runner.ts | 15 ++---- .../gesture_recognizer/gesture_recognizer.ts | 30 +++++------ .../vision/hand_landmarker/hand_landmarker.ts | 30 +++++------ .../image_classifier/image_classifier.ts | 38 ++++++++------ .../vision/image_embedder/image_embedder.ts | 38 ++++++++------ .../vision/object_detector/object_detector.ts | 36 +++++++------ 32 files changed, 262 insertions(+), 305 deletions(-) delete mode 100644 mediapipe/tasks/web/audio/core/audio_task_options.d.ts rename mediapipe/tasks/web/core/{base_options.d.ts => task_runner_options.d.ts} (85%) delete mode 100644 mediapipe/tasks/web/text/core/BUILD delete mode 100644 mediapipe/tasks/web/text/core/text_task_options.d.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index c419d3b98..6f785dd0d 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -36,7 +36,6 @@ mediapipe_ts_declaration( "audio_classifier_result.d.ts", ], deps = [ - "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index e606019f2..4e12780d2 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; @@ -56,13 +56,12 @@ export class AudioClassifier extends AudioTaskRunner { * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - const classifier = await TaskRunner.createInstance( - AudioClassifier, /* initializeCanvas= */ false, wasmFileset); - await classifier.setOptions(audioClassifierOptions); - return classifier; + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + audioClassifierOptions); } /** @@ -75,8 +74,9 @@ export class AudioClassifier extends AudioTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -86,20 +86,26 @@ export class AudioClassifier extends AudioTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts index 975b1e315..dc3c494bf 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Audio Classifier Task */ export declare interface AudioClassifierOptions extends ClassifierOptions, - AudioTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 1a66464bd..0555bb639 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -36,7 +36,6 @@ mediapipe_ts_declaration( "audio_embedder_result.d.ts", ], deps = [ - "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index c87aceabe..d08eb4791 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -25,7 +25,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioEmbedderOptions} from './audio_embedder_options'; @@ -58,23 +58,12 @@ export class AudioEmbedder extends AudioTaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, audioEmbedderOptions: AudioEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmFileset.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - AudioEmbedder, wasmFileset.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await embedder.setOptions(audioEmbedderOptions); - return embedder; + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + audioEmbedderOptions); } /** @@ -87,8 +76,9 @@ export class AudioEmbedder extends AudioTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -98,20 +88,26 @@ export class AudioEmbedder extends AudioTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts index 98f412d0f..ac22728ab 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -14,9 +14,9 @@ * limitations under the License. */ -import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Audio Embedder Task */ export declare interface AudioEmbedderOptions extends EmbedderOptions, - AudioTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 91ebbf524..9ab6c7bee 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -1,24 +1,13 @@ # This package contains options shared by all MediaPipe Audio Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_declaration( - name = "audio_task_options", - srcs = ["audio_task_options.d.ts"], - deps = [ - "//mediapipe/tasks/web/core", - ], -) - mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], deps = [ - ":audio_task_options", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", ], diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts deleted file mode 100644 index e3068625d..000000000 --- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** The options for configuring a MediaPipe Audio Task. */ -export declare interface AudioTaskOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; -} diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index ceff3895b..00cfe0253 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -14,26 +14,13 @@ * limitations under the License. */ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; - -import {AudioTaskOptions} from './audio_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner extends TaskRunner { - protected abstract baseOptions?: BaseOptionsProto|undefined; +export abstract class AudioTaskRunner extends TaskRunner { private defaultSampleRate = 48000; - /** Configures the shared options of an audio task. */ - async setOptions(options: AudioTaskOptions): Promise { - this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); - } - } - /** * Sets the sample rate for API calls that omit an explicit sample rate. * `48000` is used as a default if this method is not called. 
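The shape of this refactor shows up in the remaining files: per-task option plumbing moves into the shared TaskRunner base class, and subclasses keep only task-specific pieces such as the audio default sample rate noted above. A language-agnostic sketch of that pattern, written in Python for brevity (all names here are illustrative, not the web API):

import abc

class TaskRunnerSketch(abc.ABC):
  """Base class owning the option handling shared by every task."""

  def __init__(self):
    self.base_options = None

  def set_options(self, base_options=None, **task_options):
    # Shared path: base options are handled once, here, for all tasks.
    if base_options is not None:
      self.base_options = base_options
    self._set_task_options(**task_options)

  @abc.abstractmethod
  def _set_task_options(self, **task_options):
    """Subclasses keep only their task-specific option handling."""

class AudioRunnerSketch(TaskRunnerSketch):
  def __init__(self):
    super().__init__()
    self.default_sample_rate = 48000  # mirrors the default noted above

  def _set_task_options(self, sample_rate=None, **unused):
    if sample_rate is not None:
      self.default_sample_rate = sample_rate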
diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 1b56bf4c9..86e743928 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -17,7 +17,6 @@ mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ - "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:classification_result", ], diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index ac24a8db6..16d562262 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index d709e3409..de429690d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -7,18 +7,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_declaration( name = "core", srcs = [ - "base_options.d.ts", + "task_runner_options.d.ts", "wasm_fileset.d.ts", ], ) mediapipe_ts_library( name = "task_runner", - srcs = [ - "task_runner.ts", - ], + srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts index 1d804d629..08e7a7664 100644 --- a/mediapipe/tasks/web/core/classifier_options.d.ts +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -14,8 +14,6 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - /** Options to configure a MediaPipe Classifier Task. */ export declare interface ClassifierOptions { /** diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts index 3ec2a170c..8669acfcb 100644 --- a/mediapipe/tasks/web/core/embedder_options.d.ts +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -14,8 +14,6 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - /** Options to configure a MediaPipe Embedder Task */ export declare interface EmbedderOptions { /** diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 4085be697..c2691fc76 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,6 +14,9 @@ * limitations under the License. 
*/ +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; +import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@@ -28,7 +31,9 @@ const WasmMediaPipeImageLib = SupportModelResourcesGraphService(SupportImage(GraphRunner)); /** Base class for all MediaPipe Tasks. */ -export abstract class TaskRunner extends WasmMediaPipeImageLib { +export abstract class TaskRunner<O extends TaskRunnerOptions> extends + WasmMediaPipeImageLib { + protected abstract baseOptions: BaseOptionsProto; private processingErrors: Error[] = []; /** @@ -36,9 +41,10 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. */ - protected static async createInstance<T extends TaskRunner>( + protected static async createInstance<T extends TaskRunner<O>, + O extends TaskRunnerOptions>( type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, - fileset: WasmFileset): Promise<T> { + fileset: WasmFileset, options: O): Promise<T> { const fileLocator: FileLocator = { locateFile() { // The only file loaded with this mechanism is the Wasm binary } }; - if (initializeCanvas) { - // Fall back to an OffscreenCanvas created by the GraphRunner if - // OffscreenCanvas is available - const canvas = typeof OffscreenCanvas === 'undefined' ? - document.createElement('canvas') : - undefined; - return createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); - } else { - return createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null, - fileLocator); - } + // Initialize a canvas if requested. If OffscreenCanvas is available, we + // let the graph runner initialize it by passing `undefined`. + const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? + document.createElement('canvas') : + undefined) : + null; + const instance = await createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + await instance.setOptions(options); + return instance; } constructor( @@ -74,6 +77,14 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { this.registerModelResourcesGraphService(); } + /** Configures the shared options of a MediaPipe Task. */ + async setOptions(options: O): Promise<void> { + if (options.baseOptions) { + this.baseOptions = await convertBaseOptionsToProto( + options.baseOptions, this.baseOptions); + } + } + /** * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * over the video stream.
Will replace the previously running MediaPipe graph, diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts similarity index 85% rename from mediapipe/tasks/web/core/base_options.d.ts rename to mediapipe/tasks/web/core/task_runner_options.d.ts index 86635b8c7..aa0b4a028 100644 --- a/mediapipe/tasks/web/core/base_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -16,7 +16,7 @@ // Placeholder for internal dependency on trusted resource url -/** Options to configure MediaPipe Tasks in general. */ +/** Options to configure MediaPipe model loading and processing. */ export declare interface BaseOptions { /** * The model path to the model asset file. Only one of `modelAssetPath` or @@ -33,3 +33,9 @@ export declare interface BaseOptions { /** Overrides the default backend to use for the provided model. */ delegate?: 'cpu'|'gpu'|undefined; } + +/** Options to configure MediaPipe Tasks in general. */ +export declare interface TaskRunnerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; +} diff --git a/mediapipe/tasks/web/text/core/BUILD b/mediapipe/tasks/web/text/core/BUILD deleted file mode 100644 index 3e7faec93..000000000 --- a/mediapipe/tasks/web/text/core/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -# This package contains options shared by all MediaPipe Texxt Tasks for Web. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_declaration( - name = "text_task_options", - srcs = ["text_task_options.d.ts"], - deps = ["//mediapipe/tasks/web/core"], -) diff --git a/mediapipe/tasks/web/text/core/text_task_options.d.ts b/mediapipe/tasks/web/text/core/text_task_options.d.ts deleted file mode 100644 index 4874e35bf..000000000 --- a/mediapipe/tasks/web/text/core/text_task_options.d.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** The options for configuring a MediaPipe Text task. */ -export declare interface TextTaskOptions { - /** Options to configure the loading of the model assets. 
*/ - baseOptions?: BaseOptions; -} diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index f3d272daa..2a7de21d6 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -17,15 +17,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -38,7 +39,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 197869a36..bd2a207ce 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -17,12 +17,13 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; @@ -40,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); @@ -53,13 +54,12 @@ export class TextClassifier extends TaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). 
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - const classifier = await TaskRunner.createInstance( - TextClassifier, /* initializeCanvas= */ false, wasmFileset); - await classifier.setOptions(textClassifierOptions); - return classifier; + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + textClassifierOptions); } /** @@ -72,8 +72,9 @@ export class TextClassifier extends TaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -83,13 +84,19 @@ export class TextClassifier extends TaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -101,18 +108,20 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
*/ - async setOptions(options: TextClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: TextClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } /** * Performs Natural Language classification on the provided text and waits diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts index b50767e1a..25592deb5 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -15,8 +15,8 @@ */ import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; -import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Text Classifier Task */ export declare interface TextClassifierOptions extends ClassifierOptions, - TextTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index b858f6b83..17d105258 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -17,15 +17,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -39,6 +40,5 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 511fd2411..d2899fbe2 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -17,14 +17,15 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from 
'../../../../tasks/cc/core/proto/base_options_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; @@ -44,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR = /** * Performs embedding extraction on text. */ -export class TextEmbedder extends TaskRunner { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); @@ -57,13 +58,12 @@ export class TextEmbedder extends TaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise { - const embedder = await TaskRunner.createInstance( - TextEmbedder, /* initializeCanvas= */ false, wasmFileset); - await embedder.setOptions(textEmbedderOptions); - return embedder; + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + textEmbedderOptions); } /** @@ -76,8 +76,9 @@ export class TextEmbedder extends TaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,13 +88,19 @@ export class TextEmbedder extends TaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -105,17 +112,21 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. 
*/ - async setOptions(options: TextEmbedderOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: TextEmbedderOptions): Promise { + await super.setOptions(options); this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + /** * Performs embeding extraction on the provided text and waits synchronously * for the response. diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts index 9ea570304..7689ee0c1 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -15,8 +15,8 @@ */ import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; -import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Text Embedder Task */ export declare interface TextEmbedderOptions extends EmbedderOptions, - TextTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 1d8944f14..b389a9b01 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -17,8 +17,6 @@ mediapipe_ts_library( srcs = ["vision_task_runner.ts"], deps = [ ":vision_task_options", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index e04eb6596..76c0177a0 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** * The two running modes of a vision task. @@ -23,12 +23,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; */ export type RunningMode = 'image'|'video'; - /** The options for configuring a MediaPipe vision task. */ -export declare interface VisionTaskOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface VisionTaskOptions extends TaskRunnerOptions { /** * The running mode of the task. Default to the image mode. * Vision tasks have two running modes: diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 79ff45156..78b4859f2 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -14,24 +14,17 @@ * limitations under the License. 
*/ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ -export abstract class VisionTaskRunner extends TaskRunner { - protected abstract baseOptions?: BaseOptionsProto|undefined; - +export abstract class VisionTaskRunner extends + TaskRunner { /** Configures the shared options of a vision task. */ - async setOptions(options: VisionTaskOptions): Promise { - this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); - } + override async setOptions(options: VisionTaskOptions): Promise { + await super.setOptions(options); if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 15b6acb1a..8baee5ce3 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -88,14 +88,13 @@ export class GestureRecognizer extends * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - const recognizer = await VisionTaskRunner.createInstance( - GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); - await recognizer.setOptions(gestureRecognizerOptions); - return recognizer; + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + gestureRecognizerOptions); } /** @@ -108,8 +107,9 @@ export class GestureRecognizer extends static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return GestureRecognizer.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -119,13 +119,12 @@ export class GestureRecognizer extends * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return GestureRecognizer.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } constructor( @@ -134,6 +133,7 @@ export class GestureRecognizer extends super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); this.handLandmarksDetectorGraphOptions = @@ -151,11 +151,11 @@ export class GestureRecognizer extends this.initDefaults(); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index c657275bf..263ed4b48 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -77,13 +77,12 @@ export class HandLandmarker extends VisionTaskRunner { * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise { - const landmarker = await VisionTaskRunner.createInstance( - HandLandmarker, /* initializeCanvas= */ true, wasmFileset); - await landmarker.setOptions(handLandmarkerOptions); - return landmarker; + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + handLandmarkerOptions); } /** @@ -96,8 +95,9 @@ export class HandLandmarker extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return HandLandmarker.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -107,13 +107,12 @@ export class HandLandmarker extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return HandLandmarker.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } constructor( @@ -122,6 +121,7 @@ export class HandLandmarker extends VisionTaskRunner { super(wasmModule, glCanvas); this.options = new HandLandmarkerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarksDetectorGraphOptions = new HandLandmarksDetectorGraphOptions(); this.options.setHandLandmarksDetectorGraphOptions( @@ -132,11 +132,11 @@ export class HandLandmarker extends VisionTaskRunner { this.initDefaults(); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 185ddf9ea..90dbf9798 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -55,13 +55,12 @@ export class ImageClassifier extends VisionTaskRunner { * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - const classifier = await VisionTaskRunner.createInstance( - ImageClassifier, /* initializeCanvas= */ true, wasmFileset); - await classifier.setOptions(imageClassifierOptions); - return classifier; + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + imageClassifierOptions); } /** @@ -74,8 +73,9 @@ export class ImageClassifier extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -85,20 +85,26 @@ export class ImageClassifier extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 91352e934..559332650 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -25,7 +25,7 @@ import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/ import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; @@ -57,13 +57,12 @@ export class ImageEmbedder extends VisionTaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, imageEmbedderOptions: ImageEmbedderOptions): Promise { - const embedder = await VisionTaskRunner.createInstance( - ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); - await embedder.setOptions(imageEmbedderOptions); - return embedder; + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + imageEmbedderOptions); } /** @@ -76,8 +75,9 @@ export class ImageEmbedder extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,20 +87,26 @@ export class ImageEmbedder extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 7711c39e9..03171003f 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; @@ -54,13 +54,12 @@ export class ObjectDetector extends VisionTaskRunner { * either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, objectDetectorOptions: ObjectDetectorOptions): Promise { - const detector = await VisionTaskRunner.createInstance( - ObjectDetector, /* initializeCanvas= */ true, wasmFileset); - await detector.setOptions(objectDetectorOptions); - return detector; + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + objectDetectorOptions); } /** @@ -73,8 +72,9 @@ export class ObjectDetector extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ObjectDetector.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,17 +87,23 @@ export class ObjectDetector extends VisionTaskRunner { static async createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ObjectDetector.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } From e457039fc6350fbd2e75aa2d034f9b68af6d3410 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 2 Dec 2022 16:16:34 -0800 Subject: [PATCH 084/346] Don't inherit from GraphRunner PiperOrigin-RevId: 492584486 --- .../audio_classifier/audio_classifier.ts | 9 +++-- .../audio/audio_embedder/audio_embedder.ts | 25 ++++++++------ mediapipe/tasks/web/core/task_runner.ts | 24 +++++++------- .../text/text_classifier/text_classifier.ts | 11 ++++--- .../web/text/text_embedder/text_embedder.ts | 4 +-- .../gesture_recognizer/gesture_recognizer.ts | 33 +++++++++++-------- .../vision/hand_landmarker/hand_landmarker.ts | 26 ++++++++------- .../image_classifier/image_classifier.ts | 11 ++++--- .../vision/image_embedder/image_embedder.ts | 4 +-- .../vision/object_detector/object_detector.ts | 9 ++--- .../graph_runner/graph_runner_image_lib.ts | 2 +- .../register_model_resources_graph_service.ts | 4 +-- 12 files changed, 92 insertions(+), 70 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 4e12780d2..265ba2b33 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioClassifierResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, 
/* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); @@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoVectorListener( + this.graphRunner.attachProtoVectorListener( TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { this.addJsAudioClassificationResults(binaryProtos); }); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index d08eb4791..445dd5172 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.embeddingResults = []; this.finishProcessing(); @@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResults.push( convertFromEmbeddingResultProto(embeddingResult)); }); - this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { - for (const binaryProto of data) { - const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); - this.embeddingResults.push( - convertFromEmbeddingResultProto(embeddingResult)); - } - }); + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c2691fc76..d769139bc 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset'; const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing -const WasmMediaPipeImageLib = +const GraphRunnerImageLibType = SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class GraphRunnerImageLib extends GraphRunnerImageLibType {} /** Base class for all MediaPipe Tasks. 
*/ -export abstract class TaskRunner extends - WasmMediaPipeImageLib { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; + protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; /** @@ -67,14 +69,14 @@ export abstract class TaskRunner extends constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. - this.setAutoRenderToScreen(false); + this.graphRunner.setAutoRenderToScreen(false); // Enables use of our model resource caching graph service. - this.registerModelResourcesGraphService(); + this.graphRunner.registerModelResourcesGraphService(); } /** Configures the shared options of a MediaPipe Task. */ @@ -95,11 +97,11 @@ export abstract class TaskRunner extends * @param isBinary This should be set to true if the graph is in * binary format, and false if it is in human-readable text format. */ - override setGraph(graphData: Uint8Array, isBinary: boolean): void { - this.attachErrorListener((code, message) => { + protected setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.graphRunner.attachErrorListener((code, message) => { this.processingErrors.push(new Error(message)); }); - super.setGraph(graphData, isBinary); + this.graphRunner.setGraph(graphData, isBinary); this.handleErrors(); } @@ -108,8 +110,8 @@ export abstract class TaskRunner extends * far as possible, performing all processing until no more processing can be * done. */ - override finishProcessing(): void { - super.finishProcessing(); + protected finishProcessing(): void { + this.graphRunner.finishProcessing(); this.handleErrors(); } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index bd2a207ce..8810d4b42 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner { classify(text: string): TextClassifierResult { // Get classification result by running our MediaPipe graph. 
this.classificationResult = {classifications: []}; - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.classificationResult; @@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index d2899fbe2..62f9b06db 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner { */ embed(text: string): TextEmbedderResult { // Get text embeddings by running our MediaPipe graph. - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.embeddingResult; @@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); }); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8baee5ce3..69a8118a6 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -257,8 +257,9 @@ export class GestureRecognizer extends this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -365,18 +366,22 @@ export class GestureRecognizer extends graphConfig.addNode(recognizerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + 
this.graphRunner.attachProtoVectorListener( + HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 263ed4b48..9a0823f23 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner { this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner { graphConfig.addNode(landmarkerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 90dbf9798..40e8b5099 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner { ImageClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? 
performance.now()); this.finishProcessing(); return this.classificationResult; @@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 559332650..f8b0204ee 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner { protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { // Get embeddings by running our MediaPipe graph. - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( image, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return this.embeddings; @@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { this.addJsImageEmdedding(binaryProto); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 03171003f..e2cfe0575 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner { Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return [...this.detections]; @@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner { graphConfig.addNode(detectorNode); - this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { - this.addJsObjectDetections(binaryProto); - }); + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, binaryProto => { + this.addJsObjectDetections(binaryProto); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index e886999cb..7a4ea09e2 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -22,7 +22,7 @@ export declare interface WasmImageModule { * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. 
Example usage: - * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` + * `const GraphRunnerImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index bc9c93e8a..9f2791d80 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources { * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: - * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * GraphRunner);` + * `const GraphRunnerWithModelResourcesLib = + * SupportModelResourcesGraphService(GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService( From 35bb18945f21856f62cd99027f7702b92411dfc5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 07:22:51 -0800 Subject: [PATCH 085/346] Better handling of empty packets in vector calculators. PiperOrigin-RevId: 493000695 --- .../core/get_vector_item_calculator.h | 9 +++-- .../core/get_vector_item_calculator.proto | 3 ++ .../core/get_vector_item_calculator_test.cc | 34 ++++++++++++++----- .../core/merge_to_vector_calculator.cc | 4 +++ .../core/merge_to_vector_calculator.h | 15 ++++++-- 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index dc98ccfe7..25d90bfe6 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); absl::Status Open(CalculatorContext* cc) final { + cc->SetOffset(mediapipe::TimestampDiff(0)); auto& options = cc->Options(); RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index()); return absl::OkStatus(); @@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node { return absl::OkStatus(); } - RET_CHECK(idx >= 0 && idx < items.size()); - kOut(cc).Send(items[idx]); + RET_CHECK(idx >= 0); + RET_CHECK(options.output_empty_on_oob() || idx < items.size()); + + if (idx < items.size()) { + kOut(cc).Send(items[idx]); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/get_vector_item_calculator.proto b/mediapipe/calculators/core/get_vector_item_calculator.proto index c406283e4..9cfb579e4 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.proto +++ b/mediapipe/calculators/core/get_vector_item_calculator.proto @@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions { // Index of vector item to get. INDEX input stream can be used instead, or to // override. optional int32 item_index = 1; + + // Set to true to output an empty packet when the index is out of bounds. 
+ optional bool output_empty_on_oob = 2; } diff --git a/mediapipe/calculators/core/get_vector_item_calculator_test.cc b/mediapipe/calculators/core/get_vector_item_calculator_test.cc index c148aa9d1..c2974e20a 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator_test.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator_test.cc @@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() { )"); } -CalculatorRunner MakeRunnerWithOptions(int set_index) { - return CalculatorRunner(absl::StrFormat(R"( +CalculatorRunner MakeRunnerWithOptions(int set_index, + bool output_empty_on_oob = false) { + return CalculatorRunner( + absl::StrFormat(R"( calculator: "TestGetIntVectorItemCalculator" input_stream: "VECTOR:vector_stream" output_stream: "ITEM:item_stream" options { [mediapipe.GetVectorItemCalculatorOptions.ext] { item_index: %d + output_empty_on_oob: %s } } )", - set_index)); + set_index, output_empty_on_oob ? "true" : "false")); } void AddInputVector(CalculatorRunner& runner, const std::vector& inputs, @@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { @@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { @@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { @@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); +} + +TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) { + const int try_index = 3; + CalculatorRunner runner = MakeRunnerWithOptions(try_index, true); + const std::vector inputs = {1, 2, 3}; + + AddInputVector(runner, inputs, 1); + + MP_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Tag("ITEM").packets; + EXPECT_THAT(outputs, testing::ElementsAre()); } TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) { diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index cca64bc9a..5f05ad725 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -23,5 +23,9 @@ namespace api2 { typedef MergeToVectorCalculator MergeImagesToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator); +typedef MergeToVectorCalculator + MergeGpuBuffersToVectorCalculator; 
+MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator);
+
 }  // namespace api2
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.h b/mediapipe/calculators/core/merge_to_vector_calculator.h
index bed616695..f63d86ee4 100644
--- a/mediapipe/calculators/core/merge_to_vector_calculator.h
+++ b/mediapipe/calculators/core/merge_to_vector_calculator.h
@@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node {
     return absl::OkStatus();
   }
 
+  absl::Status Open(::mediapipe::CalculatorContext* cc) {
+    cc->SetOffset(::mediapipe::TimestampDiff(0));
+    return absl::OkStatus();
+  }
+
   absl::Status Process(CalculatorContext* cc) {
     const int input_num = kIn(cc).Count();
-    std::vector<T> output_vector(input_num);
-    std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(),
-                   [](const auto& elem) -> T { return elem.Get(); });
+    std::vector<T> output_vector;
+    for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) {
+      const auto& elem = *it;
+      if (!elem.IsEmpty()) {
+        output_vector.push_back(elem.Get());
+      }
+    }
     kOut(cc).Send(output_vector);
     return absl::OkStatus();
   }

From 4f8eaee20f5c02d932b8bacecd1afb0655d84130 Mon Sep 17 00:00:00 2001
From: Adam Cozzette
Date: Mon, 5 Dec 2022 11:33:21 -0800
Subject: [PATCH 086/346] Internal change

PiperOrigin-RevId: 493065632
---
 mediapipe/graphs/iris_tracking/calculators/BUILD        | 1 -
 mediapipe/java/com/google/mediapipe/framework/jni/BUILD | 7 +++----
 mediapipe/modules/hand_landmark/calculators/BUILD       | 1 -
 mediapipe/modules/objectron/calculators/BUILD           | 4 ----
 mediapipe/util/tracking/BUILD                           | 1 -
 5 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD
index 3a3d57a0f..f5124b464 100644
--- a/mediapipe/graphs/iris_tracking/calculators/BUILD
+++ b/mediapipe/graphs/iris_tracking/calculators/BUILD
@@ -97,7 +97,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/framework:calculator_framework",
-        "//mediapipe/framework/formats:image_file_properties_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD
index 4926e2f3c..4540f63a6 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD
@@ -84,12 +84,11 @@ cc_library(
     deps = [
         ":class_registry",
         ":jni_util",
-        "//mediapipe/framework:calculator_framework",
-        "//mediapipe/framework:calculator_profile_cc_proto",
-        "//mediapipe/framework/tool:calculator_graph_template_cc_proto",
         "//mediapipe/framework/formats:image_format_cc_proto",
-        "//mediapipe/framework/formats:matrix_data_cc_proto",
         "//mediapipe/framework/formats:time_series_header_cc_proto",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_profile_cc_proto",
+        "//mediapipe/framework:calculator_framework",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
diff --git a/mediapipe/modules/hand_landmark/calculators/BUILD b/mediapipe/modules/hand_landmark/calculators/BUILD
index b2a8efe37..b42ec94de 100644
--- a/mediapipe/modules/hand_landmark/calculators/BUILD
+++ b/mediapipe/modules/hand_landmark/calculators/BUILD
@@ -24,7 +24,6 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index eeeaee5f4..14cea526f 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -275,7 +275,6 @@ cc_library( ":tflite_tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -299,7 +298,6 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -316,13 +314,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_cc_proto", - ":belief_decoder_config_cc_proto", ":decoder", ":lift_2d_frame_annotation_to_3d_calculator_cc_proto", ":tensor_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 6bca24446..816af2533 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -282,7 +282,6 @@ cc_library( srcs = ["motion_models_cv.cc"], hdrs = ["motion_models_cv.h"], deps = [ - ":camera_motion_cc_proto", ":motion_models", ":motion_models_cc_proto", "//mediapipe/framework/port:opencv_core", From 69b27b246a3f11e775791805eb2c2b4858ed9412 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 14:14:13 -0800 Subject: [PATCH 087/346] Adds a public function for creating TaskRunner instances. PiperOrigin-RevId: 493109736 --- mediapipe/tasks/web/core/task_runner.ts | 46 ++++++++++++++++--------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index d769139bc..e2ab42e31 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -32,6 +32,34 @@ const GraphRunnerImageLibType = /** An implementation of the GraphRunner that supports image operations */ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} +/** + * Creates a new instance of a Mediapipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ +export async function +createTaskRunner, O extends TaskRunnerOptions>( + type: WasmMediaPipeConstructor, initializeCanvas: boolean, + fileset: WasmFileset, options: O): Promise { + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return fileset.wasmBinaryPath.toString(); + } + }; + + // Initialize a canvas if requested. If OffscreenCanvas is availble, we + // let the graph runner initialize it by passing `undefined`. 
+ const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? + document.createElement('canvas') : + undefined) : + null; + const instance = await createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + await instance.setOptions(options); + return instance; +} + /** Base class for all MediaPipe Tasks. */ export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; @@ -47,23 +75,7 @@ export abstract class TaskRunner { O extends TaskRunnerOptions>( type: WasmMediaPipeConstructor, initializeCanvas: boolean, fileset: WasmFileset, options: O): Promise { - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return fileset.wasmBinaryPath.toString(); - } - }; - - // Initialize a canvas if requested. If OffscreenCanvas is availble, we - // let the graph runner initialize it by passing `undefined`. - const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? - document.createElement('canvas') : - undefined) : - null; - const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); - await instance.setOptions(options); - return instance; + return createTaskRunner(type, initializeCanvas, fileset, options); } constructor( From 3ad03bee0be95376cf4606d39b201dab5a0afcb5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 14:48:07 -0800 Subject: [PATCH 088/346] Adds missing visibility rule. PiperOrigin-RevId: 493118880 --- mediapipe/calculators/tensor/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 645189a07..577ac4111 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -366,6 +366,9 @@ cc_test( cc_library( name = "universal_sentence_encoder_preprocessor_calculator", srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", From 99d1dd6fbb130f9f262365ae334b2ca22c819478 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Mon, 5 Dec 2022 15:28:52 -0800 Subject: [PATCH 089/346] Internal change PiperOrigin-RevId: 493129643 --- docs/build_py_api_docs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fe706acd3..46546012d 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -26,7 +26,6 @@ from absl import app from absl import flags from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report. @@ -68,10 +67,7 @@ def gen_api_docs(): code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - # This callback ensures that docs are only generated for objects that - # are explicitly imported in your __init__.py files. There are other - # options but this is a good starting point. 
- callbacks=[public_api.explicit_package_contents_filter], + callbacks=[], ) doc_generator.build(_OUTPUT_DIR.value) From 1e76d47a71602ba0ac4a089f625bbd667a7f184b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 5 Dec 2022 16:18:36 -0800 Subject: [PATCH 090/346] Checks if a custom global resource provider is used as the first step of loading the model resources on all platforms. PiperOrigin-RevId: 493141519 --- mediapipe/tasks/cc/core/BUILD | 1 + mediapipe/tasks/cc/core/model_resources.cc | 30 +++++++++++----------- mediapipe/util/resource_util.cc | 2 ++ mediapipe/util/resource_util_custom.h | 3 +++ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 202f3ea3c..f8004d257 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -117,6 +117,7 @@ cc_library_with_tflite( "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/util:resource_util", + "//mediapipe/util:resource_util_custom", "//mediapipe/util/tflite:error_reporter", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index d5c12ee95..7819f6213 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/util/resource_util.h" +#include "mediapipe/util/resource_util_custom.h" #include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -99,21 +100,20 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { -#ifdef __EMSCRIPTEN__ - // In browsers, the model file may require a custom ResourceProviderFn to - // provide the model content. The open() method may not work in this case. - // Thus, loading the model content from the model file path in advance with - // the help of GetResourceContents. - MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( - model_file_->file_name(), model_file_->mutable_file_content())); - model_file_->clear_file_name(); -#else - // If the model file name is a relative path, searches the file in a - // platform-specific location and returns the absolute path on success. - ASSIGN_OR_RETURN(std::string path_to_resource, - mediapipe::PathToResourceAsFile(model_file_->file_name())); - model_file_->set_file_name(path_to_resource); -#endif // __EMSCRIPTEN__ + if (HasCustomGlobalResourceProvider()) { + // If the model contents are provided via a custom ResourceProviderFn, the + // open() method may not work. Thus, loads the model content from the + // model file path in advance with the help of GetResourceContents. + MP_RETURN_IF_ERROR(GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); + } else { + // If the model file name is a relative path, searches the file in a + // platform-specific location and returns the absolute path on success. 
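+      // The resolved path is stored back on the proto so that the file
+      // handler created below can open it directly.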
+ ASSIGN_OR_RETURN(std::string path_to_resource, + PathToResourceAsFile(model_file_->file_name())); + model_file_->set_file_name(path_to_resource); + } } ASSIGN_OR_RETURN( model_file_handler_, diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc index 8f40154a0..38636f32e 100644 --- a/mediapipe/util/resource_util.cc +++ b/mediapipe/util/resource_util.cc @@ -37,6 +37,8 @@ absl::Status GetResourceContents(const std::string& path, std::string* output, return internal::DefaultGetResourceContents(path, output, read_as_binary); } +bool HasCustomGlobalResourceProvider() { return resource_provider_ != nullptr; } + void SetCustomGlobalResourceProvider(ResourceProviderFn fn) { resource_provider_ = std::move(fn); } diff --git a/mediapipe/util/resource_util_custom.h b/mediapipe/util/resource_util_custom.h index 6bc1513c6..e74af8b2e 100644 --- a/mediapipe/util/resource_util_custom.h +++ b/mediapipe/util/resource_util_custom.h @@ -10,6 +10,9 @@ namespace mediapipe { typedef std::function ResourceProviderFn; +// Returns true if files are provided via a custom resource provider. +bool HasCustomGlobalResourceProvider(); + // Overrides the behavior of GetResourceContents. void SetCustomGlobalResourceProvider(ResourceProviderFn fn); From 3174b20fbe8225c35433d86f3a82d29645bb82bb Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 17:32:57 -0800 Subject: [PATCH 091/346] Move segmentation calculator and options out of 'components' folder. PiperOrigin-RevId: 493157929 --- mediapipe/tasks/cc/components/proto/BUILD | 24 ------------------- .../tasks/cc/vision/image_segmenter/BUILD | 8 +++---- .../image_segmenter/calculators}/BUILD | 4 ++-- .../tensors_to_segmentation_calculator.cc | 14 +++++------ .../tensors_to_segmentation_calculator.proto | 6 +++-- ...tensors_to_segmentation_calculator_test.cc | 4 +--- .../vision/image_segmenter/image_segmenter.cc | 4 ++-- .../image_segmenter/image_segmenter_graph.cc | 6 ++--- .../image_segmenter/image_segmenter_test.cc | 2 +- .../cc/vision/image_segmenter/proto/BUILD | 7 +++++- .../proto/image_segmenter_graph_options.proto | 4 ++-- .../proto/segmenter_options.proto | 4 ++-- mediapipe/tasks/python/vision/BUILD | 2 +- .../tasks/python/vision/image_segmenter.py | 2 +- 14 files changed, 35 insertions(+), 56 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/proto/BUILD rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/BUILD (94%) rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator.cc (95%) rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator.proto (82%) rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator_test.cc (99%) rename mediapipe/tasks/cc/{components => vision/image_segmenter}/proto/segmenter_options.proto (92%) diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD deleted file mode 100644 index 569023753..000000000 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -mediapipe_proto_library( - name = "segmenter_options_proto", - srcs = ["segmenter_options.proto"], -) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 2124fe6e0..4c9c6e69c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -28,7 +28,6 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", @@ -36,6 +35,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -56,17 +56,17 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD similarity index 94% rename from mediapipe/tasks/cc/components/calculators/tensor/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index 6e4322a8f..dcd7fb407 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -25,7 +25,7 @@ 
mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", ], ) @@ -45,7 +45,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/util:label_map_cc_proto", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc similarity index 95% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 40585848f..668de0057 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO consolidate TensorsToSegmentationCalculator. #include #include #include @@ -35,14 +34,14 @@ limitations under the License. #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/util/label_map.pb.h" +// TODO: consolidate TensorToSegmentationCalculator. namespace mediapipe { namespace tasks { - namespace { using ::mediapipe::Image; @@ -51,9 +50,9 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Node; using ::mediapipe::api2::Output; using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; void StableSoftmax(absl::Span values, absl::Span activated_values) { @@ -90,7 +89,7 @@ void Sigmoid(absl::Span values, // the size to resize masks to. // // Output: -// Segmentation: Segmenation proto. +// Segmentation: Segmentation proto. 
// // Options: // See tensors_to_segmentation_calculator.proto @@ -132,8 +131,7 @@ class TensorsToSegmentationCalculator : public Node { absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { - options_ = - cc->Options(); + options_ = cc->Options(); RET_CHECK_NE(options_.segmenter_options().output_type(), SegmenterOptions::UNSPECIFIED) << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto similarity index 82% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index c26cf910a..b0fdfdd32 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -15,10 +15,11 @@ limitations under the License. syntax = "proto2"; +// TODO: consolidate TensorToSegmentationCalculator. package mediapipe.tasks; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; message TensorsToSegmentationCalculatorOptions { @@ -26,7 +27,8 @@ message TensorsToSegmentationCalculatorOptions { optional TensorsToSegmentationCalculatorOptions ext = 458105876; } - optional components.proto.SegmenterOptions segmenter_options = 1; + optional mediapipe.tasks.vision.image_segmenter.proto.SegmenterOptions + segmenter_options = 1; // Identifying information for each classification label. map label_items = 2; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc similarity index 99% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 55e46d72b..54fb9b816 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -33,10 +33,9 @@ limitations under the License. 
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" namespace mediapipe { -namespace api2 { namespace { @@ -374,5 +373,4 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { expected_index, buffer_indices))); } -} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 6dce1b4ea..bbee714c6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -18,12 +18,12 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" namespace mediapipe { namespace tasks { @@ -44,7 +44,7 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; -using ::mediapipe::tasks::components::proto::SegmenterOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index d5eb5af0d..5531968c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -54,10 +54,10 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 752a116dd..d5ea088a1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -28,11 +28,11 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index 3b14060f1..9523dd679 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -18,13 +18,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "segmenter_options_proto", + srcs = ["segmenter_options.proto"], +) + mediapipe_proto_library( name = "image_segmenter_graph_options_proto", srcs = ["image_segmenter_graph_options.proto"], deps = [ + ":segmenter_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 166e2e8e0..4d8100842 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,8 +18,8 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "ImageSegmenterGraphOptionsProto"; @@ -37,5 +37,5 @@ message ImageSegmenterGraphOptions { optional string display_names_locale = 2 [default = "en"]; // Segmentation output options. - optional components.proto.SegmenterOptions segmenter_options = 3; + optional SegmenterOptions segmenter_options = 3; } diff --git a/mediapipe/tasks/cc/components/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto similarity index 92% rename from mediapipe/tasks/cc/components/proto/segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index ca9986707..be2b8a589 100644 --- a/mediapipe/tasks/cc/components/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -15,9 +15,9 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.vision.image_segmenter.proto; -option java_package = "com.google.mediapipe.tasks.components.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "SegmenterOptionsProto"; // Shared options used by image segmentation tasks. diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e94507eed..29e7577e8 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -69,8 +69,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 62fc8bb7c..22a37cb3e 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet -from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 +from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls From af43687f2e3c774ff7b0f1f4881d456952a6aadd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 20:07:29 -0800 Subject: [PATCH 092/346] Open-sources a unit test. PiperOrigin-RevId: 493184055 --- .../text_classifier/text_classifier_test.cc | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 8f73914fc..799885eac 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -38,10 +38,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" -namespace mediapipe { -namespace tasks { -namespace text { -namespace text_classifier { +namespace mediapipe::tasks::text::text_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } } +} // namespace + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { MP_ASSERT_OK(classifier->Close()); } -} // namespace -} // namespace text_classifier -} // namespace text -} // namespace tasks -} // namespace mediapipe +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify(ss_for_positive_review.str())); + TextClassifierResult expected; + std::vector categories; + +// Predicted scores are slightly different on Mac OS. +#ifdef __APPLE__ + categories.push_back( + {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}); +#else + categories.push_back( + {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}); +#endif // __APPLE__ + + expected.classifications.emplace_back( + Classifications{/*categories=*/categories, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(result, expected); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace mediapipe::tasks::text::text_classifier From 1761cdcef4ff6fd37d04d15de765eccd7c0a5bcc Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 5 Dec 2022 22:11:00 -0800 Subject: [PATCH 093/346] Fix aar breakage caused by missing "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark". 
PiperOrigin-RevId: 493204770
---
 .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index d91c03cc2..c6aba3c66 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -285,6 +285,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",

From b0b38a0013c819a6db4156330cbbe2e0dab11bd8 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 6 Dec 2022 08:29:35 -0800
Subject: [PATCH 094/346] Internal change

PiperOrigin-RevId: 493313240
---
 mediapipe/gpu/metal_shared_resources.h       | 40 +++++++++++
 mediapipe/gpu/metal_shared_resources.mm      | 73 ++++++++++++++++++++
 mediapipe/gpu/metal_shared_resources_test.mm | 49 +++++++++++++
 3 files changed, 162 insertions(+)
 create mode 100644 mediapipe/gpu/metal_shared_resources.h
 create mode 100644 mediapipe/gpu/metal_shared_resources.mm
 create mode 100644 mediapipe/gpu/metal_shared_resources_test.mm

diff --git a/mediapipe/gpu/metal_shared_resources.h b/mediapipe/gpu/metal_shared_resources.h
new file mode 100644
index 000000000..341860a2d
--- /dev/null
+++ b/mediapipe/gpu/metal_shared_resources.h
@@ -0,0 +1,40 @@
+#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
+#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
+
+#import <CoreVideo/CVMetalTextureCache.h>
+#import <CoreVideo/CoreVideo.h>
+#import <Foundation/NSObject.h>
+#import <Metal/Metal.h>
+
+#ifndef __OBJC__
+#error This class must be built as Objective-C++.
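+// (The declarations below use Objective-C constructs, so this header cannot
+// be included from plain C++ translation units.)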
+#endif  // !__OBJC__
+
+@interface MPPMetalSharedResources : NSObject {
+}
+
+- (instancetype)init NS_DESIGNATED_INITIALIZER;
+
+@property(readonly) id<MTLDevice> mtlDevice;
+@property(readonly) id<MTLCommandQueue> mtlCommandQueue;
+#if COREVIDEO_SUPPORTS_METAL
+@property(readonly) CVMetalTextureCacheRef mtlTextureCache;
+#endif
+
+@end
+
+namespace mediapipe {
+
+class MetalSharedResources {
+ public:
+  MetalSharedResources();
+  ~MetalSharedResources();
+  MPPMetalSharedResources* resources() { return resources_; }
+
+ private:
+  MPPMetalSharedResources* resources_;
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
diff --git a/mediapipe/gpu/metal_shared_resources.mm b/mediapipe/gpu/metal_shared_resources.mm
new file mode 100644
index 000000000..80d755a01
--- /dev/null
+++ b/mediapipe/gpu/metal_shared_resources.mm
@@ -0,0 +1,73 @@
+#import "mediapipe/gpu/metal_shared_resources.h"
+
+@interface MPPMetalSharedResources ()
+@end
+
+@implementation MPPMetalSharedResources {
+}
+
+@synthesize mtlDevice = _mtlDevice;
+@synthesize mtlCommandQueue = _mtlCommandQueue;
+#if COREVIDEO_SUPPORTS_METAL
+@synthesize mtlTextureCache = _mtlTextureCache;
+#endif
+
+- (instancetype)init {
+  self = [super init];
+  if (self) {
+  }
+  return self;
+}
+
+- (void)dealloc {
+#if COREVIDEO_SUPPORTS_METAL
+  if (_mtlTextureCache) {
+    CFRelease(_mtlTextureCache);
+    _mtlTextureCache = NULL;
+  }
+#endif
+}
+
+- (id<MTLDevice>)mtlDevice {
+  @synchronized(self) {
+    if (!_mtlDevice) {
+      _mtlDevice = MTLCreateSystemDefaultDevice();
+    }
+  }
+  return _mtlDevice;
+}
+
+- (id<MTLCommandQueue>)mtlCommandQueue {
+  @synchronized(self) {
+    if (!_mtlCommandQueue) {
+      _mtlCommandQueue = [self.mtlDevice newCommandQueue];
+    }
+  }
+  return _mtlCommandQueue;
+}
+
+#if COREVIDEO_SUPPORTS_METAL
+- (CVMetalTextureCacheRef)mtlTextureCache {
+  @synchronized(self) {
+    if (!_mtlTextureCache) {
+      CVReturn __unused err =
+          CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
+      NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err,
+               self.mtlDevice);
+      // TODO: register and flush metal caches too.
+    }
+  }
+  return _mtlTextureCache;
+}
+#endif
+
+@end
+
+namespace mediapipe {
+
+MetalSharedResources::MetalSharedResources() {
+  resources_ = [[MPPMetalSharedResources alloc] init];
+}
+MetalSharedResources::~MetalSharedResources() {}
+
+}  // namespace mediapipe
diff --git a/mediapipe/gpu/metal_shared_resources_test.mm b/mediapipe/gpu/metal_shared_resources_test.mm
new file mode 100644
index 000000000..9eb53a9b7
--- /dev/null
+++ b/mediapipe/gpu/metal_shared_resources_test.mm
@@ -0,0 +1,49 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
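+
+// Unit tests for MPPMetalSharedResources and its MetalSharedResources C++
+// wrapper.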
+ +#import +#import + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/threadpool.h" + +#import "mediapipe/gpu/gpu_shared_data_internal.h" +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResourcesTests : XCTestCase { +} +@end + +@implementation MPPMetalSharedResourcesTests + +// This test verifies that the internal Objective-C object is correctly +// released when the C++ wrapper is released. +- (void)testCorrectlyReleased { + __weak id metalRes = nil; + std::weak_ptr weakGpuRes; + @autoreleasepool { + auto maybeGpuRes = mediapipe::GpuResources::Create(); + XCTAssertTrue(maybeGpuRes.ok()); + weakGpuRes = *maybeGpuRes; + metalRes = (**maybeGpuRes).metal_shared().resources(); + XCTAssertNotEqual(weakGpuRes.lock(), nullptr); + XCTAssertNotNil(metalRes); + } + XCTAssertEqual(weakGpuRes.lock(), nullptr); + XCTAssertNil(metalRes); +} + +@end From fb0b96115f148c8c293f6cc3ddc7b3ed67b8043c Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 08:33:51 -0800 Subject: [PATCH 095/346] Open up mediapipe core calculators' visibility. PiperOrigin-RevId: 493314353 --- mediapipe/calculators/core/BUILD | 88 +--------------------------- mediapipe/calculators/internal/BUILD | 6 +- 2 files changed, 4 insertions(+), 90 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 3b658eb5b..29bca5fa6 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "concatenate_vector_calculator_proto", srcs = ["concatenate_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -32,7 +31,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "dequantize_byte_array_calculator_proto", srcs = ["dequantize_byte_array_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_cloner_calculator_proto", srcs = ["packet_cloner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_resampler_calculator_proto", srcs = ["packet_resampler_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_thinner_calculator_proto", srcs = ["packet_thinner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "split_vector_calculator_proto", srcs = ["split_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -82,7 +76,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = 
"quantize_float_vector_calculator_proto", srcs = ["quantize_float_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -92,7 +85,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "sequence_shift_calculator_proto", srcs = ["sequence_shift_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -102,7 +94,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "gate_calculator_proto", srcs = ["gate_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -112,7 +103,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "constant_side_packet_calculator_proto", srcs = ["constant_side_packet_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -124,7 +114,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "clip_vector_size_calculator_proto", srcs = ["clip_vector_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -134,7 +123,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "flow_limiter_calculator_proto", srcs = ["flow_limiter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -144,7 +132,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "graph_profile_calculator_proto", srcs = ["graph_profile_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -154,7 +141,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "get_vector_item_calculator_proto", srcs = ["get_vector_item_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -164,7 +150,6 @@ mediapipe_proto_library( cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -193,7 +178,6 @@ cc_library( name = "begin_loop_calculator", srcs = ["begin_loop_calculator.cc"], hdrs = ["begin_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -216,7 +200,6 @@ cc_library( name = "end_loop_calculator", srcs = ["end_loop_calculator.cc"], hdrs = ["end_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -258,7 +241,6 @@ cc_test( cc_library( name = "concatenate_vector_calculator_hdr", hdrs = ["concatenate_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -284,7 +266,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", 
"//mediapipe/framework/api2:node", @@ -311,7 +292,6 @@ cc_library( cc_library( name = "concatenate_detection_vector_calculator", srcs = ["concatenate_detection_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator", "//mediapipe/framework:calculator_framework", @@ -323,7 +303,6 @@ cc_library( cc_library( name = "concatenate_proto_list_calculator", srcs = ["concatenate_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -372,7 +351,6 @@ cc_library( name = "clip_vector_size_calculator", srcs = ["clip_vector_size_calculator.cc"], hdrs = ["clip_vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,7 +366,6 @@ cc_library( cc_library( name = "clip_detection_vector_size_calculator", srcs = ["clip_detection_vector_size_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator", "//mediapipe/framework:calculator_framework", @@ -415,9 +392,6 @@ cc_test( cc_library( name = "counting_source_calculator", srcs = ["counting_source_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -430,9 +404,6 @@ cc_library( cc_library( name = "make_pair_calculator", srcs = ["make_pair_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -461,9 +432,6 @@ cc_test( cc_library( name = "matrix_multiply_calculator", srcs = ["matrix_multiply_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -477,9 +445,6 @@ cc_library( cc_library( name = "matrix_subtract_calculator", srcs = ["matrix_subtract_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -493,9 +458,6 @@ cc_library( cc_library( name = "mux_calculator", srcs = ["mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -508,9 +470,6 @@ cc_library( cc_library( name = "non_zero_calculator", srcs = ["non_zero_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -556,9 +515,6 @@ cc_test( cc_library( name = "packet_cloner_calculator", srcs = ["packet_cloner_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":packet_cloner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -587,7 +543,6 @@ cc_test( cc_library( name = "packet_inner_join_calculator", srcs = ["packet_inner_join_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -611,7 +566,6 @@ cc_test( cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -643,9 +597,6 @@ cc_test( cc_library( name = "pass_through_calculator", srcs = 
["pass_through_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -656,9 +607,6 @@ cc_library( cc_library( name = "round_robin_demux_calculator", srcs = ["round_robin_demux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -670,9 +618,6 @@ cc_library( cc_library( name = "immediate_mux_calculator", srcs = ["immediate_mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -684,7 +629,6 @@ cc_library( cc_library( name = "packet_presence_calculator", srcs = ["packet_presence_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -713,7 +657,6 @@ cc_test( cc_library( name = "previous_loopback_calculator", srcs = ["previous_loopback_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -729,7 +672,6 @@ cc_library( cc_library( name = "flow_limiter_calculator", srcs = ["flow_limiter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -746,7 +688,6 @@ cc_library( cc_library( name = "string_to_int_calculator", srcs = ["string_to_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -759,7 +700,6 @@ cc_library( cc_library( name = "default_side_packet_calculator", srcs = ["default_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -771,7 +711,6 @@ cc_library( cc_library( name = "side_packet_to_stream_calculator", srcs = ["side_packet_to_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -822,9 +761,6 @@ cc_library( name = "packet_resampler_calculator", srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -884,7 +820,6 @@ cc_test( cc_test( name = "matrix_multiply_calculator_test", srcs = ["matrix_multiply_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_multiply_calculator", "//mediapipe/framework:calculator_framework", @@ -900,7 +835,6 @@ cc_test( cc_test( name = "matrix_subtract_calculator_test", srcs = ["matrix_subtract_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_subtract_calculator", "//mediapipe/framework:calculator_framework", @@ -950,7 +884,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -996,7 +929,6 @@ cc_test( cc_library( name = "split_proto_list_calculator", srcs = ["split_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1028,7 +960,6 @@ 
cc_test( cc_library( name = "dequantize_byte_array_calculator", srcs = ["dequantize_byte_array_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":dequantize_byte_array_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1054,7 +985,6 @@ cc_test( cc_library( name = "quantize_float_vector_calculator", srcs = ["quantize_float_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":quantize_float_vector_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1080,7 +1010,6 @@ cc_test( cc_library( name = "sequence_shift_calculator", srcs = ["sequence_shift_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1105,7 +1034,6 @@ cc_test( cc_library( name = "gate_calculator", srcs = ["gate_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":gate_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1131,7 +1059,6 @@ cc_test( cc_library( name = "matrix_to_vector_calculator", srcs = ["matrix_to_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1167,7 +1094,6 @@ cc_test( cc_library( name = "merge_calculator", srcs = ["merge_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1193,7 +1119,6 @@ cc_test( cc_library( name = "stream_to_side_packet_calculator", srcs = ["stream_to_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -1219,7 +1144,6 @@ cc_test( cc_library( name = "constant_side_packet_calculator", srcs = ["constant_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":constant_side_packet_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1249,7 +1173,6 @@ cc_test( cc_library( name = "graph_profile_calculator", srcs = ["graph_profile_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_profile_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1291,7 +1214,6 @@ cc_library( name = "get_vector_item_calculator", srcs = ["get_vector_item_calculator.cc"], hdrs = ["get_vector_item_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":get_vector_item_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1325,7 +1247,6 @@ cc_library( name = "vector_indices_calculator", srcs = ["vector_indices_calculator.cc"], hdrs = ["vector_indices_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1351,7 +1272,6 @@ cc_library( name = "vector_size_calculator", srcs = ["vector_size_calculator.cc"], hdrs = ["vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1365,9 +1285,6 @@ cc_library( cc_library( name = "packet_sequencer_calculator", srcs = ["packet_sequencer_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:contract", @@ -1402,7 +1319,6 @@ cc_library( name = "merge_to_vector_calculator", srcs = ["merge_to_vector_calculator.cc"], hdrs = ["merge_to_vector_calculator.h"], - visibility = 
["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1416,7 +1332,6 @@ cc_library( mediapipe_proto_library( name = "bypass_calculator_proto", srcs = ["bypass_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1426,7 +1341,6 @@ mediapipe_proto_library( cc_library( name = "bypass_calculator", srcs = ["bypass_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bypass_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 54b6c20f1..caade2dc3 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:private"]) proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -29,14 +29,14 @@ mediapipe_cc_proto_library( name = "callback_packet_calculator_cc_proto", srcs = ["callback_packet_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [":callback_packet_calculator_proto"], ) cc_library( name = "callback_packet_calculator", srcs = ["callback_packet_calculator.cc"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", From ab0b0ab558c633bc996c41923f9325269cc76e3c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 10:22:31 -0800 Subject: [PATCH 096/346] Change visibility for MP Tasks Web to public PiperOrigin-RevId: 493343996 --- mediapipe/tasks/web/audio/BUILD | 1 + mediapipe/tasks/web/audio/audio_classifier/BUILD | 2 ++ mediapipe/tasks/web/audio/audio_embedder/BUILD | 2 ++ mediapipe/tasks/web/core/BUILD | 1 + mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/text_classifier/BUILD | 2 ++ mediapipe/tasks/web/text/text_embedder/BUILD | 2 ++ mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/gesture_recognizer/BUILD | 2 ++ mediapipe/tasks/web/vision/hand_landmarker/BUILD | 2 ++ mediapipe/tasks/web/vision/image_classifier/BUILD | 2 ++ mediapipe/tasks/web/vision/image_embedder/BUILD | 2 ++ mediapipe/tasks/web/vision/object_detector/BUILD | 2 ++ 13 files changed, 22 insertions(+) diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index d08602521..9d26f1118 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 6f785dd0d..dc82a4a24 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", srcs = 
["audio_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "audio_classifier_options.d.ts", "audio_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 0555bb639..dc84d0cd6 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_embedder", srcs = ["audio_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "audio_embedder_options.d.ts", "audio_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index de429690d..be1b71f5d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -28,6 +28,7 @@ mediapipe_ts_library( mediapipe_ts_library( name = "fileset_resolver", srcs = ["fileset_resolver.ts"], + visibility = ["//visibility:public"], deps = [":core"], ) diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 159db1a0d..32f43d4b6 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 2a7de21d6..07f78ac20 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", srcs = ["text_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "text_classifier_options.d.ts", "text_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 17d105258..7d796fb7e 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_embedder", srcs = ["text_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "text_embedder_options.d.ts", "text_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/vision/BUILD 
b/mediapipe/tasks/web/vision/BUILD index 42bc0a494..93493e873 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index ddfd1a327..6e2e56196 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", srcs = ["gesture_recognizer.ts"], + visibility = ["//visibility:public"], deps = [ ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", @@ -42,6 +43,7 @@ mediapipe_ts_declaration( "gesture_recognizer_options.d.ts", "gesture_recognizer_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index fc3e6ef1f..520898e34 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "hand_landmarker", srcs = ["hand_landmarker.ts"], + visibility = ["//visibility:public"], deps = [ ":hand_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", @@ -38,6 +39,7 @@ mediapipe_ts_declaration( "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index ebe64ecf4..848c162ae 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_classifier", srcs = ["image_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "image_classifier_options.d.ts", "image_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index 2f012dc5e..6c9d80fb1 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_embedder", srcs = ["image_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "image_embedder_options.d.ts", "image_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git 
a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 198585258..f73790895 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "object_detector", srcs = ["object_detector.ts"], + visibility = ["//visibility:public"], deps = [ ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", @@ -32,6 +33,7 @@ mediapipe_ts_declaration( "object_detector_options.d.ts", "object_detector_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", From c6e6f9e0b9b35d055cd83016e468a8c30a7b153b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 11:05:47 -0800 Subject: [PATCH 097/346] Fix aar breakage caused by missing "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite". PiperOrigin-RevId: 493357585 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index c6aba3c66..727d020a6 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -21,7 +21,6 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", - "//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", @@ -43,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", From 6deef1a5f13c4af5e38abe96f8aabbba733dcdcb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 12:07:51 -0800 Subject: [PATCH 098/346] Allow specifying tag_suffix in the templated CreateModelResources method. 
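This makes it possible for a task graph to cache more than one model
resources object by giving each one a distinct tag suffix. A minimal sketch
of the intended call pattern (ExampleTaskGraph, ExampleGraphOptions, and
aux_file below are hypothetical placeholders, not part of this change):

  // Sketch only: assumes ExampleGraphOptions carries the primary model asset
  // and aux_file is a std::unique_ptr<proto::ExternalFile> prepared for a
  // second model elsewhere in the subgraph.
  class ExampleTaskGraph : public core::ModelTaskGraph {
   public:
    absl::StatusOr<CalculatorGraphConfig> GetConfig(
        SubgraphContext* sc) override {
      // Distinct suffixes keep the two cached resources from colliding in
      // the model resources cache service.
      ASSIGN_OR_RETURN(const core::ModelResources* main_resources,
                       CreateModelResources<ExampleGraphOptions>(sc, "_main"));
      ASSIGN_OR_RETURN(const core::ModelResources* aux_resources,
                       CreateModelResources(sc, std::move(aux_file), "_aux"));
      CalculatorGraphConfig config;
      // ... wire up calculators using main_resources / aux_resources ...
      return config;
    }
  };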
PiperOrigin-RevId: 493375701
---
 mediapipe/tasks/cc/core/model_task_graph.cc | 2 +-
 mediapipe/tasks/cc/core/model_task_graph.h  | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc
index 66434483b..0cb556ec2 100644
--- a/mediapipe/tasks/cc/core/model_task_graph.cc
+++ b/mediapipe/tasks/cc/core/model_task_graph.cc
@@ -186,7 +186,7 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
 absl::StatusOr<const ModelAssetBundleResources*>
 ModelTaskGraph::CreateModelAssetBundleResources(
     SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
-    const std::string tag_suffix) {
+    std::string tag_suffix) {
   auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
   bool has_file_pointer_meta = external_file->has_file_pointer_meta();
   // if external file is set by file pointer, no need to add the model asset
diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h
index 50dcc903b..3068b2c46 100644
--- a/mediapipe/tasks/cc/core/model_task_graph.h
+++ b/mediapipe/tasks/cc/core/model_task_graph.h
@@ -59,14 +59,16 @@ class ModelTaskGraph : public Subgraph {
   // creates a local model resources object that can only be used in the graph
   // construction stage. The returned model resources pointer will provide graph
   // authors with the access to the metadata extractor and the tflite model.
+  // If more than one model resources are created in a graph, the model
+  // resources graph service adds the tag_suffix to support multiple resources.
   template <typename Options>
   absl::StatusOr<const ModelResources*> CreateModelResources(
-      SubgraphContext* sc) {
+      SubgraphContext* sc, std::string tag_suffix = "") {
     auto external_file = std::make_unique<proto::ExternalFile>();
     external_file->Swap(sc->MutableOptions<Options>()
                             ->mutable_base_options()
                             ->mutable_model_asset());
-    return CreateModelResources(sc, std::move(external_file));
+    return CreateModelResources(sc, std::move(external_file), tag_suffix);
   }

   // If the model resources graph service is available, creates a model
@@ -83,7 +85,7 @@ class ModelTaskGraph : public Subgraph {
   // resources.
   absl::StatusOr<const ModelResources*> CreateModelResources(
       SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
-      const std::string tag_suffix = "");
+      std::string tag_suffix = "");

   // If the model resources graph service is available, creates a model asset
   // bundle resources object from the subgraph context, and caches the created

From cdc14522e2821a60ec1ee208430e364917e21985 Mon Sep 17 00:00:00 2001
From: Khanh LeViet
Date: Tue, 6 Dec 2022 13:01:06 -0800
Subject: [PATCH 099/346] Added issue templates for MP Preview.

PiperOrigin-RevId: 493389856
---
 .../ISSUE_TEMPLATE/11-tasks-issue.md       | 25 +++++++++++++++++++
 .../ISSUE_TEMPLATE/12-model-maker-issue.md | 25 +++++++++++++++++++
 2 files changed, 50 insertions(+)
 create mode 100644 mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md
 create mode 100644 mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md

diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md b/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md
new file mode 100644
index 000000000..ab7b38368
--- /dev/null
+++ b/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md
@@ -0,0 +1,25 @@
+---
+name: "Tasks Issue"
+about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc.
+labels: type:support
+
+---
+Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue.
+
+**System information** (Please provide as much relevant information as possible)
+- Have I written custom code (as opposed to using a stock example script provided in MediaPipe):
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4):
+- MediaPipe Tasks SDK version:
+- Task name (e.g. Object detection, Gesture recognition etc.):
+- Programming Language and version ( e.g. C++, Python, Java):
+
+**Describe the expected behavior:**
+
+**Standalone code you may have used to try to get what you need :**
+
+If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem:
+
+**Other info / Complete Logs :**
+Include any logs or source code that would be helpful to
+diagnose the problem. If including tracebacks, please include the full
+traceback. Large logs and files should be attached:
diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md b/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md
new file mode 100644
index 000000000..687360957
--- /dev/null
+++ b/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md
@@ -0,0 +1,25 @@
+---
+name: "Model Maker Issue"
+about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc.
+labels: type:support
+
+---
+Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue.
+
+**System information** (Please provide as much relevant information as possible)
+- Have I written custom code (as opposed to using a stock example script provided in MediaPipe):
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Python version (e.g. 3.8):
+- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/):
+- Task name (e.g. Image classification, Gesture recognition etc.):
+
+**Describe the expected behavior:**
+
+**Standalone code you may have used to try to get what you need :**
+
+If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem:
+
+**Other info / Complete Logs :**
+Include any logs or source code that would be helpful to
+diagnose the problem. If including tracebacks, please include the full
+traceback. Large logs and files should be attached:
From 0f32072804a4e078c9f64ae8cb48d9b1777a679f Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 6 Dec 2022 14:01:49 -0800
Subject: [PATCH 100/346] Move ISSUE_TEMPLATE files to .github folder

PiperOrigin-RevId: 493405734
---
 .../opensource_only => .github}/ISSUE_TEMPLATE/11-tasks-issue.md | 0
 .../ISSUE_TEMPLATE/12-model-maker-issue.md                       | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename {mediapipe/opensource_only => .github}/ISSUE_TEMPLATE/11-tasks-issue.md (100%)
 rename {mediapipe/opensource_only => .github}/ISSUE_TEMPLATE/12-model-maker-issue.md (100%)

diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md
similarity index 100%
rename from mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md
rename to .github/ISSUE_TEMPLATE/11-tasks-issue.md
diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md
similarity index 100%
rename from mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md
rename to .github/ISSUE_TEMPLATE/12-model-maker-issue.md

From 9bc7b120de85d4991292d831bd844264c783350b Mon Sep 17 00:00:00 2001
From: Khanh LeViet
Date: Tue, 6 Dec 2022 15:12:25 -0800
Subject: [PATCH 101/346] Tweaked the issue templates.

PiperOrigin-RevId: 493424927
---
 .github/ISSUE_TEMPLATE/11-tasks-issue.md       | 2 +-
 .github/ISSUE_TEMPLATE/12-model-maker-issue.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md
index ab7b38368..264371120 100644
--- a/.github/ISSUE_TEMPLATE/11-tasks-issue.md
+++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md
@@ -1,6 +1,6 @@
 ---
 name: "Tasks Issue"
-about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc.
+about: Use this template for assistance with using MediaPipe Tasks to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms.
 labels: type:support
 
 ---
diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md
index 687360957..258390d5e 100644
--- a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md
+++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md
@@ -1,6 +1,6 @@
 ---
 name: "Model Maker Issue"
-about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc.
+about: Use this template for assistance with using MediaPipe Model Maker to create custom on-device ML solutions.
labels: type:support --- From fca0f5806b470a47a3c74a7085d32c32a12d61f1 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 15:16:42 -0800 Subject: [PATCH 102/346] Create Build Rules for Apple Frameworks PiperOrigin-RevId: 493426040 --- mediapipe/examples/ios/common/BUILD | 10 ++-- mediapipe/examples/ios/faceeffect/BUILD | 10 ++-- mediapipe/gpu/BUILD | 64 ++++++++-------------- mediapipe/objc/BUILD | 68 ++++++++++------------- third_party/apple_frameworks/BUILD | 73 +++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 91 deletions(-) create mode 100644 third_party/apple_frameworks/BUILD diff --git a/mediapipe/examples/ios/common/BUILD b/mediapipe/examples/ios/common/BUILD index 9b8f8a968..bfa770cec 100644 --- a/mediapipe/examples/ios/common/BUILD +++ b/mediapipe/examples/ios/common/BUILD @@ -29,12 +29,6 @@ objc_library( "Base.lproj/LaunchScreen.storyboard", "Base.lproj/Main.storyboard", ], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], visibility = [ "//mediapipe:__subpackages__", ], @@ -42,6 +36,10 @@ objc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 50a6f68bd..e0c3abb86 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -73,13 +73,11 @@ objc_library( "//mediapipe/modules/face_landmark:face_landmark.tflite", ], features = ["-layering_check"], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 7a8aa6557..f5cb9f715 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -472,13 +472,13 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Accelerate", - "CoreGraphics", - "CoreVideo", - ], visibility = ["//visibility:public"], - deps = ["//mediapipe/objc:util"], + deps = [ + "//mediapipe/objc:util", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreVideo", + ], ) objc_library( @@ -510,13 +510,11 @@ objc_library( "-x objective-c++", "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", "@com_google_absl//absl/time", "@google_toolbox_for_mac//:GTM_Defines", ], @@ -808,15 +806,13 @@ objc_library( "-Wno-shorten-64-to-32", ], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":gpu_shared_data_internal", ":graph_support", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", 
"@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -1020,16 +1016,14 @@ objc_library( name = "metal_copy_calculator", srcs = ["MetalCopyCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/gpu:copy_calculator_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1038,15 +1032,13 @@ objc_library( name = "metal_rgb_weight_calculator", srcs = ["MetalRgbWeightCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1055,15 +1047,13 @@ objc_library( name = "metal_sobel_calculator", srcs = ["MetalSobelCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1072,15 +1062,13 @@ objc_library( name = "metal_sobel_compute_calculator", srcs = ["MetalSobelComputeCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1090,15 +1078,13 @@ objc_library( srcs = ["MPSSobelCalculator.mm"], copts = ["-std=c++17"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) @@ -1106,15 +1092,13 @@ objc_library( objc_library( name = "mps_threshold_calculator", srcs = ["MPSThresholdCalculator.mm"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index fafdfee8a..c71c02b6d 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -68,7 +68,6 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = ["Accelerate"], # This build rule is public to allow external customers to build their own iOS apps. 
visibility = ["//visibility:public"], deps = [ @@ -90,6 +89,7 @@ objc_library( "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -120,13 +120,13 @@ objc_library( ], "//conditions:default": [], }), - sdk_frameworks = [ - "AVFoundation", - "CoreVideo", - "Foundation", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], + deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", + ], ) objc_library( @@ -140,16 +140,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -164,16 +162,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", "//mediapipe/gpu:gl_calculator_helper", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -188,13 +184,11 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Foundation", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", "@com_google_absl//absl/strings", ], ) @@ -211,23 +205,21 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "OpenGLES", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", ":Weakify", ":mediapipe_framework_ios", "//mediapipe/framework:calculator_framework", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:OpenGLES", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) @@ -245,16 +237,6 @@ objc_library( data = [ "testdata/googlelogo_color_272x92dp.png", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", @@ -263,6 +245,14 @@ objc_library( ":mediapipe_framework_ios", ":mediapipe_input_sources_ios", "//mediapipe/calculators/core:pass_through_calculator", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + 
"//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD new file mode 100644 index 000000000..05f830e81 --- /dev/null +++ b/third_party/apple_frameworks/BUILD @@ -0,0 +1,73 @@ +# Build rules to inject Apple Frameworks + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "CoreGraphics", + linkopts = ["-framework CoreGraphics"], +) + +cc_library( + name = "CoreMedia", + linkopts = ["-framework CoreMedia"], +) + +cc_library( + name = "UIKit", + linkopts = ["-framework UIKit"], +) + +cc_library( + name = "Accelerate", + linkopts = ["-framework Accelerate"], +) + +cc_library( + name = "CoreVideo", + linkopts = ["-framework CoreVideo"], +) + +cc_library( + name = "Metal", + linkopts = ["-framework Metal"], +) + +cc_library( + name = "MetalPerformanceShaders", + linkopts = ["-framework MetalPerformanceShaders"], +) + +cc_library( + name = "AVFoundation", + linkopts = ["-framework AVFoundation"], +) + +cc_library( + name = "Foundation", + linkopts = ["-framework Foundation"], +) + +cc_library( + name = "CoreImage", + linkopts = ["-framework CoreImage"], +) + +cc_library( + name = "XCTest", + linkopts = ["-framework XCTest"], +) + +cc_library( + name = "GLKit", + linkopts = ["-framework GLKit"], +) + +cc_library( + name = "OpenGLES", + linkopts = ["-framework OpenGLES"], +) + +cc_library( + name = "QuartzCore", + linkopts = ["-framework QuartzCore"], +) From 576c6da173c4b84b13787c0e6926acab05118880 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 6 Dec 2022 15:22:03 -0800 Subject: [PATCH 103/346] Internal change PiperOrigin-RevId: 493427500 --- mediapipe/tasks/python/audio/BUILD | 4 +- .../tasks/python/audio/audio_classifier.py | 34 ++++++++-- .../tasks/python/audio/audio_embedder.py | 21 ++++-- .../tasks/python/components/processors/BUILD | 9 --- .../python/components/processors/__init__.py | 3 - .../components/processors/embedder_options.py | 68 ------------------- mediapipe/tasks/python/components/utils/BUILD | 5 +- .../components/utils/cosine_similarity.py | 2 - mediapipe/tasks/python/test/audio/BUILD | 2 - .../test/audio/audio_classifier_test.py | 20 ++---- .../python/test/audio/audio_embedder_test.py | 10 +-- mediapipe/tasks/python/test/text/BUILD | 2 - .../python/test/text/text_classifier_test.py | 2 - .../python/test/text/text_embedder_test.py | 10 +-- mediapipe/tasks/python/test/vision/BUILD | 2 - .../test/vision/image_classifier_test.py | 52 +++++--------- .../python/test/vision/image_embedder_test.py | 10 +-- mediapipe/tasks/python/text/BUILD | 4 +- .../tasks/python/text/text_classifier.py | 35 ++++++++-- mediapipe/tasks/python/text/text_embedder.py | 20 ++++-- mediapipe/tasks/python/vision/BUILD | 4 +- .../tasks/python/vision/image_classifier.py | 35 ++++++++-- .../tasks/python/vision/image_embedder.py | 20 ++++-- 23 files changed, 162 insertions(+), 212 deletions(-) delete mode 100644 mediapipe/tasks/python/components/processors/embedder_options.py diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index 2e5815ff0..ce7c5ce08 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -29,11 +29,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", 
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -51,11 +51,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index d82b6fe27..cc87d6221 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult _AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options_module.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -62,16 +62,31 @@ class AudioClassifierOptions: mode for running classification on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the classification results asynchronously. 
-    classifier_options: Options for configuring the classifier behavior, such as
-      score threshold, number of results, etc.
+    display_names_locale: The locale to use for display names specified through
+      the TFLite Model Metadata.
+    max_results: The maximum number of top-scored classification results to
+      return.
+    score_threshold: Overrides the ones provided in the model metadata. Results
+      below this value are rejected.
+    category_allowlist: Allowlist of category names. If non-empty,
+      classification results whose category name is not in this set will be
+      filtered out. Duplicate or unknown category names are ignored. Mutually
+      exclusive with `category_denylist`.
+    category_denylist: Denylist of category names. If non-empty, classification
+      results whose category name is in this set will be filtered out. Duplicate
+      or unknown category names are ignored. Mutually exclusive with
+      `category_allowlist`.
     result_callback: The user-defined result callback for processing audio
       stream data. The result callback should only be specified when the running
       mode is set to the audio stream mode.
   """
   base_options: _BaseOptions
   running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
-  classifier_options: Optional[_ClassifierOptions] = dataclasses.field(
-      default_factory=_ClassifierOptions)
+  display_names_locale: Optional[str] = None
+  max_results: Optional[int] = None
+  score_threshold: Optional[float] = None
+  category_allowlist: Optional[List[str]] = None
+  category_denylist: Optional[List[str]] = None
   result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None

   @doc_controls.do_not_generate_docs
@@ -79,7 +94,12 @@ class AudioClassifierOptions:
     """Generates an AudioClassifierOptions protobuf object."""
     base_options_proto = self.base_options.to_pb2()
     base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
-    classifier_options_proto = self.classifier_options.to_pb2()
+    classifier_options_proto = _ClassifierOptionsProto(
+        score_threshold=self.score_threshold,
+        category_allowlist=self.category_allowlist,
+        category_denylist=self.category_denylist,
+        display_names_locale=self.display_names_locale,
+        max_results=self.max_results)

     return _AudioClassifierGraphOptionsProto(
         base_options=base_options_proto,
diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py
index 629e21882..4c37783e9 100644
--- a/mediapipe/tasks/python/audio/audio_embedder.py
+++ b/mediapipe/tasks/python/audio/audio_embedder.py
@@ -21,11 +21,11 @@ from mediapipe.python import packet_getter
 from mediapipe.python._framework_bindings import packet
 from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2
 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
+from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 from mediapipe.tasks.python.audio.core import base_audio_task_api
 from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
-from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
 from mediapipe.tasks.python.components.utils import cosine_similarity
 from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.core import task_info as
task_info_module @@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult _AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -63,16 +63,22 @@ class AudioEmbedderOptions: stream mode for running embedding extraction on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the embedding results asynchronously. - embedder_options: Options for configuring the embedder behavior, such as - l2_normalize and quantize. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -80,7 +86,8 @@ class AudioEmbedderOptions: """Generates an AudioEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _AudioEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index eef368db0..f87a579b0 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -28,12 +28,3 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) - -py_library( - name = "embedder_options", - srcs = ["embedder_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py index adcb38757..0eb73abe0 100644 --- a/mediapipe/tasks/python/components/processors/__init__.py +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -15,12 +15,9 @@ """MediaPipe Tasks Components Processors API.""" import mediapipe.tasks.python.components.processors.classifier_options -import mediapipe.tasks.python.components.processors.embedder_options 
ClassifierOptions = classifier_options.ClassifierOptions -EmbedderOptions = embedder_options.EmbedderOptions # Remove unnecessary modules to avoid duplication in API docs. del classifier_options -del embedder_options del mediapipe diff --git a/mediapipe/tasks/python/components/processors/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py deleted file mode 100644 index c86a91105..000000000 --- a/mediapipe/tasks/python/components/processors/embedder_options.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Embedder options data class.""" - -import dataclasses -from typing import Any, Optional - -from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions - - -@dataclasses.dataclass -class EmbedderOptions: - """Shared options used by all embedding extraction tasks. - - Attributes: - l2_normalize: Whether to normalize the returned feature vector with L2 norm. - Use this option only if the model does not already contain a native - L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and - L2 norm is thus achieved through TF Lite inference. - quantize: Whether the returned embedding should be quantized to bytes via - scalar quantization. Embeddings are implicitly assumed to be unit-norm and - therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use - the l2_normalize option if this is not the case. - """ - - l2_normalize: Optional[bool] = None - quantize: Optional[bool] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbedderOptionsProto: - """Generates a EmbedderOptions protobuf object.""" - return _EmbedderOptionsProto( - l2_normalize=self.l2_normalize, quantize=self.quantize) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions': - """Creates a `EmbedderOptions` object from the given protobuf object.""" - return EmbedderOptions( - l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. 
- """ - if not isinstance(other, EmbedderOptions): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index b64d04c72..31114f326 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -23,8 +23,5 @@ licenses(["notice"]) py_library( name = "cosine_similarity", srcs = ["cosine_similarity.py"], - deps = [ - "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", - ], + deps = ["//mediapipe/tasks/python/components/containers:embedding_result"], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 486c02ece..ff8979458 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -16,10 +16,8 @@ import numpy as np from mediapipe.tasks.python.components.containers import embedding_result -from mediapipe.tasks.python.components.processors import embedder_options _Embedding = embedding_result.Embedding -_EmbedderOptions = embedder_options.EmbedderOptions def _compute_cosine_similarity(u, v): diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 9278cea55..43f1d417c 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -30,7 +30,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], @@ -48,7 +47,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 0d067e587..75146547c 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 
'yamnet_audio_classifier_with_metadata.tflite' @@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - score_threshold=0.9))) as classifier: + score_threshold=0.9)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['Speech']))) as classifier: + category_allowlist=['Speech'])) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase): r'exclusive options.'): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar'])) + category_allowlist=['foo'], + category_denylist=['bar']) with _AudioClassifier.create_from_options(options) as unused_classifier: pass @@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase): _AudioClassifierOptions( base_options=_BaseOptions( model_asset_path=self.two_heads_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - classifier_options=_ClassifierOptions(max_results=1), + max_results=1, result_callback=save_result) classifier = _AudioClassifier.create_from_options(options) audio_data_list = self._read_wav_file_as_stream(audio_file) diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 2e38ea2ee..f280235d7 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -26,7 +26,6 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -35,7 +34,6 @@ 
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options.EmbedderOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' @@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _AudioEmbedderOptions( - base_options=base_options, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize)) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _AudioEmbedder.create_from_options(options) as embedder: embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) @@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase): options = _AudioEmbedderOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize), + l2_normalize=l2_normalize, + quantize=quantize, result_callback=save_result) with _AudioEmbedder.create_from_options(options) as embedder: diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 38e56bdb2..0e2b06012 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -28,7 +28,6 @@ py_test( deps = [ "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_classifier", @@ -44,7 +43,6 @@ py_test( ], deps = [ "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_embedder", diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index 8678d2194..8df7dce86 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -21,14 +21,12 @@ from absl.testing import parameterized from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_classifier TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category _Classifications = classification_result_module.Classifications _TextClassifier = text_classifier.TextClassifier diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index c9090026c..1346ba373 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ 
b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -21,13 +21,11 @@ from absl.testing import parameterized import numpy as np from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_embedder _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _TextEmbedder = text_embedder.TextEmbedder _TextEmbedderOptions = text_embedder.TextEmbedderOptions @@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _TextEmbedder.create_from_options(options) # Extracts both embeddings. @@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _TextEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. positive_text0 = "it's a charming and often affecting journey" diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 066107421..48ecc30b3 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,7 +49,6 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", @@ -69,7 +68,6 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_embedder", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 77f16278f..cbeaf36bd 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import 
classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier @@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode ImageClassifierResult = classification_result_module.ClassificationResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category _Classifications = classification_result_module.Classifications _Image = image.Image @@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. @@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase): def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) - custom_classifier_options = _ClassifierOptions(max_results=1) - options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _ImageClassifierOptions(base_options=base_options, max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( @@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase): _generate_soccer_ball_results().to_pb2()) def test_score_threshold_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase): f'{classification}') def test_max_results_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
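    # The construction above shows the flattened options API introduced in
    # this commit: tuning fields such as score_threshold and max_results now
    # live directly on _ImageClassifierOptions. A minimal caller sketch
    # (illustrative only; 'classifier.tflite' is a placeholder path):
    #
    #   options = _ImageClassifierOptions(
    #       base_options=_BaseOptions(model_asset_path='classifier.tflite'),
    #       max_results=3,
    #       score_threshold=0.5)
    #   with _ImageClassifier.create_from_options(options) as classifier:
    #     image_result = classifier.classify(test_image)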
image_result = classifier.classify(self.test_image) @@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase): len(categories), _MAX_RESULTS, 'Too many results returned.') def test_allow_list_option(self): - custom_classifier_options = _ClassifierOptions( - category_allowlist=_ALLOW_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=_ALLOW_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase): f'Label {label} found but not in label allow list') def test_deny_list_option(self): - custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_denylist=_DENY_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): - custom_classifier_options = _ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=['foo'], + category_denylist=['bar']) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): - custom_classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=1) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_for_video(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=4) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( @@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase): _generate_burger_results().to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. 
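    # Sketch of the region-of-interest flow this test exercises (the
    # _ImageProcessingOptions alias and the ROI values here are assumptions
    # for illustration, not part of this patch):
    #
    #   roi = _Rect(left=0.45, top=0.3, right=0.9, bottom=0.9)
    #   image_processing_options = _ImageProcessingOptions(roi)
    #   result = classifier.classify_for_video(
    #       test_image, timestamp, image_processing_options)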
test_image = _Image.create_from_file( @@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_async_calls_with_illegal_timestamp(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, result_callback=mock.MagicMock()) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 100) @@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions( - max_results=4, score_threshold=threshold) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, + score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): @@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 4bb96bad6..11c0cf002 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -24,7 +24,6 @@ import numpy as np from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_embedder @@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder @@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _ImageEmbedder.create_from_options(options) image_processing_options = None @@ -186,10 +182,8 @@ class 
ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _ImageEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index 10b4b8a6e..e2a51cdbd 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -28,9 +28,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -47,9 +47,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 9711e8b3a..fdb20f0ef 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,14 +14,14 @@ """MediaPipe text classifier task.""" import dataclasses -from typing import Optional +from typing import Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _TaskInfo = task_info_module.TaskInfo 
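# A minimal before/after sketch of the caller-facing change in this commit
# (illustrative only; assumes `base_options` is already constructed):
#
#   # Before: classification settings nested in a processors option object.
#   options = TextClassifierOptions(
#       base_options=base_options,
#       classifier_options=classifier_options.ClassifierOptions(max_results=3))
#
#   # After: the same settings are fields of the task options themselves.
#   options = TextClassifierOptions(base_options=base_options, max_results=3)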
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' @@ -46,17 +46,38 @@ class TextClassifierOptions: Attributes: base_options: Base options for the text classifier task. - classifier_options: Options for the text classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. """ base_options: _BaseOptions - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: """Generates an TextClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _TextClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index a9e560ac9..be899636d 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -19,9 +19,9 @@ from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _TaskInfo = task_info_module.TaskInfo _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' @@ -47,17 +47,25 @@ class TextEmbedderOptions: Attributes: base_options: Base 
options for the text embedder task. - embedder_options: Options for the text embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. """ base_options: _BaseOptions - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: """Generates an TextEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _TextEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 29e7577e8..241ca4341 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -47,10 +47,10 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -89,9 +89,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 6cbce7860..b60d18e31 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,17 +14,17 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, Mapping, Optional +from typing import Callable, Mapping, Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.containers.proto import 
classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -63,15 +63,31 @@ class ImageClassifierOptions: objects on single image inputs. 2) The video mode for classifying objects on the decoded frames of a video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - classifier_options: Options for the image classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
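  Example (a minimal sketch; 'classifier.tflite' is a placeholder model path):
    options = ImageClassifierOptions(
        base_options=BaseOptions(model_asset_path='classifier.tflite'),
        max_results=3,
        score_threshold=0.5)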
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None @@ -80,7 +96,12 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index a58dca3ae..0bae21bda 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -62,15 +62,22 @@ class ImageEmbedderOptions: image on single image inputs. 2) The video mode for embedding image on the decoded frames of a video. 3) The live stream mode for embedding image on a live stream of input data, such as from camera. - embedder_options: Options for the image embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. 
Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None @@ -79,7 +86,8 @@ class ImageEmbedderOptions: """Generates an ImageEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _ImageEmbedderGraphOptionsProto( base_options=base_options_proto, From 1167f61f9825cc80e3e81b53b08a59f1a19ef456 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 18:02:35 -0800 Subject: [PATCH 104/346] Remove generic Options template argument from TaskRunner PiperOrigin-RevId: 493462947 --- mediapipe/tasks/web/audio/core/BUILD | 5 +---- .../tasks/web/audio/core/audio_task_runner.ts | 3 +-- mediapipe/tasks/web/core/task_runner.ts | 14 ++++++-------- .../web/text/text_classifier/text_classifier.ts | 2 +- .../tasks/web/text/text_embedder/text_embedder.ts | 2 +- .../tasks/web/vision/core/vision_task_runner.ts | 3 +-- 6 files changed, 11 insertions(+), 18 deletions(-) diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 9ab6c7bee..cea689838 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -7,8 +7,5 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], - deps = [ - "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - ], + deps = ["//mediapipe/tasks/web/core:task_runner"], ) diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index 00cfe0253..24d78378d 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -15,10 +15,9 @@ */ import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner extends TaskRunner { +export abstract class AudioTaskRunner extends TaskRunner { private defaultSampleRate = 48000; /** diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index e2ab42e31..71e159dce 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -37,10 +37,9 @@ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. 
 */
-export async function
-createTaskRunner<T extends TaskRunner<O>, O extends TaskRunnerOptions>(
+export async function createTaskRunner<T extends TaskRunner>(
     type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
-    fileset: WasmFileset, options: O): Promise<T> {
+    fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
   const fileLocator: FileLocator = {
     locateFile() {
       // The only file loaded with this mechanism is the Wasm binary
@@ -61,7 +60,7 @@ createTaskRunner<T extends TaskRunner<O>, O extends TaskRunnerOptions>(
 }
 
 /** Base class for all MediaPipe Tasks. */
-export abstract class TaskRunner<O extends TaskRunnerOptions> {
+export abstract class TaskRunner {
   protected abstract baseOptions: BaseOptionsProto;
   protected graphRunner: GraphRunnerImageLib;
   private processingErrors: Error[] = [];
@@ -71,10 +70,9 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> {
    * supported and loads the relevant WASM binary.
    * @return A fully instantiated instance of `T`.
    */
-  protected static async createInstance<T extends TaskRunner<O>,
-      O extends TaskRunnerOptions>(
+  protected static async createInstance<T extends TaskRunner>(
       type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
-      fileset: WasmFileset, options: O): Promise<T> {
+      fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
     return createTaskRunner(type, initializeCanvas, fileset, options);
   }
 
@@ -92,7 +90,7 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> {
   }
 
   /** Configures the shared options of a MediaPipe Task. */
-  async setOptions(options: O): Promise<void> {
+  async setOptions(options: TaskRunnerOptions): Promise<void> {
     if (options.baseOptions) {
       this.baseOptions = await convertBaseOptionsToProto(
           options.baseOptions, this.baseOptions);
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
index 8810d4b42..4a8588836 100644
--- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
+++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
@@ -41,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH =
 // tslint:disable:jspb-use-builder-pattern
 
 /** Performs Natural Language classification. */
-export class TextClassifier extends TaskRunner<TextClassifierOptions> {
+export class TextClassifier extends TaskRunner {
   private classificationResult: TextClassifierResult = {classifications: []};
   private readonly options = new TextClassifierGraphOptions();
 
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
index 62f9b06db..cd5bc644e 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
@@ -45,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR =
 /**
  * Performs embedding extraction on text.
  */
-export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
+export class TextEmbedder extends TaskRunner {
   private embeddingResult: TextEmbedderResult = {embeddings: []};
   private readonly options = new TextEmbedderGraphOptionsProto();
 
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index 78b4859f2..3432b521b 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -20,8 +20,7 @@ import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 import {VisionTaskOptions} from './vision_task_options';
 
 /** Base class for all MediaPipe Vision Tasks. */
-export abstract class VisionTaskRunner<T> extends
-    TaskRunner<VisionTaskOptions> {
+export abstract class VisionTaskRunner<T> extends TaskRunner {
   /** Configures the shared options of a vision task. */
  override async setOptions(options: VisionTaskOptions): Promise<void> {
    await super.setOptions(options);

From 402834b4f2236ed2d707f0d20c0ebd2d1a42a721 Mon Sep 17 00:00:00 2001
From: Mark McDonald
Date: Tue, 6 Dec 2022 19:46:33 -0800
Subject: [PATCH 105/346] Internal change

PiperOrigin-RevId: 493480322
---
 docs/build_model_maker_api_docs.py | 81 ++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 docs/build_model_maker_api_docs.py

diff --git a/docs/build_model_maker_api_docs.py b/docs/build_model_maker_api_docs.py
new file mode 100644
index 000000000..7732b7d56
--- /dev/null
+++ b/docs/build_model_maker_api_docs.py
@@ -0,0 +1,81 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""MediaPipe Model Maker reference docs generation script.
+
+This script generates API reference docs for the `mediapipe-model-maker` PIP package.
+
+$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker
+$> python build_model_maker_api_docs.py
+"""
+
+import os
+
+from absl import app
+from absl import flags
+
+from tensorflow_docs.api_generator import generate_lib
+
+try:
+  # mediapipe has not been set up to work with bazel yet, so catch & report.
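+  # Doing the import inside the try block (rather than at module top level)
+  # turns a missing optional dependency into the actionable ImportError
+  # raised below, instead of a bare ModuleNotFoundError at import time.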
+ import mediapipe_model_maker # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe-model-maker`.') from e + + +PROJECT_SHORT_NAME = 'mediapipe_model_maker' +PROJECT_FULL_NAME = 'MediaPipe Model Maker' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe-model-maker package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)], + base_dir=os.path.dirname(mediapipe_model_maker.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + callbacks=[], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) From 523d16dffab5d066879b300230cc9ac26ad49128 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 6 Dec 2022 23:54:11 -0800 Subject: [PATCH 106/346] Make GpuBuffer a shared_ptr to a storage collection PiperOrigin-RevId: 493519590 --- mediapipe/gpu/BUILD | 2 + mediapipe/gpu/gpu_buffer.cc | 102 +++++++++++++++++++++--------- mediapipe/gpu/gpu_buffer.h | 105 +++++++++++++++++-------------- mediapipe/gpu/gpu_buffer_test.cc | 22 +++++++ 4 files changed, 156 insertions(+), 75 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index f5cb9f715..009eb3f9e 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -289,7 +289,9 @@ cc_library( deps = [ ":gpu_buffer_format", ":gpu_buffer_storage", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", ":gpu_buffer_storage_image_frame", diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index 388960b11..628e86099 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -3,6 +3,7 @@ #include #include +#include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" @@ -25,57 +26,101 @@ struct StorageTypeFormatter { } // namespace std::string GpuBuffer::DebugString() const { - return absl::StrCat("GpuBuffer[", - absl::StrJoin(storages_, ", ", StorageTypeFormatter()), - "]"); + return holder_ ? 
absl::StrCat("GpuBuffer[", width(), "x", height(), " ", + format(), " as ", holder_->DebugString(), "]") + : "GpuBuffer[invalid]"; } -internal::GpuBufferStorage* GpuBuffer::GetStorageForView( +std::string GpuBuffer::StorageHolder::DebugString() const { + absl::MutexLock lock(&mutex_); + return absl::StrJoin(storages_, ", ", StorageTypeFormatter()); +} + +internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView( TypeId view_provider_type, bool for_writing) const { - const std::shared_ptr* chosen_storage = nullptr; + std::shared_ptr chosen_storage; + std::function()> conversion; - // First see if any current storage supports the view. - for (const auto& s : storages_) { - if (s->can_down_cast_to(view_provider_type)) { - chosen_storage = &s; - break; - } - } - - // Then try to convert existing storages to one that does. - // TODO: choose best conversion. - if (!chosen_storage) { + { + absl::MutexLock lock(&mutex_); + // First see if any current storage supports the view. for (const auto& s : storages_) { - if (auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider( - view_provider_type, s->storage_type())) { - if (auto new_storage = converter(s)) { - storages_.push_back(new_storage); - chosen_storage = &storages_.back(); + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + + // Then try to convert existing storages to one that does. + // TODO: choose best conversion. + if (!chosen_storage) { + for (const auto& s : storages_) { + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + conversion = absl::bind_front(converter, s); break; } } } } + // Avoid invoking a converter or factory while holding the mutex. + // Two reasons: + // 1. Readers that don't need a conversion will not be blocked. + // 2. We use mutexes to make sure GL contexts are not used simultaneously on + // different threads, and we also rely on Mutex's deadlock detection + // heuristic, which enforces a consistent mutex acquisition order. + // This function is likely to be called within a GL context, and the + // conversion function may in turn use a GL context, and this may cause a + // false positive in the deadlock detector. + // TODO: we could use Mutex::ForgetDeadlockInfo instead. + if (conversion) { + auto new_storage = conversion(); + absl::MutexLock lock(&mutex_); + // Another reader might have already completed and inserted the same + // conversion. TODO: prevent this? + for (const auto& s : storages_) { + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + if (!chosen_storage) { + storages_.push_back(std::move(new_storage)); + chosen_storage = storages_.back(); + } + } + if (for_writing) { + // This will temporarily hold storages to be released, and do so while the + // lock is not held (see above). + decltype(storages_) old_storages; + using std::swap; if (chosen_storage) { // Discard all other storages. - storages_ = {*chosen_storage}; - chosen_storage = &storages_.back(); + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); + storages_ = {chosen_storage}; } else { // Allocate a new storage supporting the requested view. 
if (auto factory = internal::GpuBufferStorageRegistry::Get() .StorageFactoryForViewProvider(view_provider_type)) { - if (auto new_storage = factory(width(), height(), format())) { + if (auto new_storage = factory(width_, height_, format_)) { + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); storages_ = {std::move(new_storage)}; - chosen_storage = &storages_.back(); + chosen_storage = storages_.back(); } } } } - return chosen_storage ? chosen_storage->get() : nullptr; + + // It is ok to return a non-owning storage pointer here because this object + // ensures the storage's lifetime. Overwriting a GpuBuffer while readers are + // active would violate this, but it's not allowed in MediaPipe. + return chosen_storage ? chosen_storage.get() : nullptr; } internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( @@ -84,8 +129,7 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " - << absl::StrJoin(storages_, ", ", - StorageTypeFormatter()); + << (holder_ ? holder_->DebugString() : "invalid"); DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); return *chosen_storage; } diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 56507d92f..b9a88aa53 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -15,9 +15,12 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_H_ +#include +#include #include #include +#include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" @@ -56,8 +59,7 @@ class GpuBuffer { // Creates an empty buffer of a given size and format. It will be allocated // when a view is requested. GpuBuffer(int width, int height, Format format) - : GpuBuffer(std::make_shared(width, height, - format)) {} + : holder_(std::make_shared(width, height, format)) {} // Copy and move constructors and assignment operators are supported. GpuBuffer(const GpuBuffer& other) = default; @@ -70,9 +72,8 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. - explicit GpuBuffer(std::shared_ptr storage) { - storages_.push_back(std::move(storage)); - } + explicit GpuBuffer(std::shared_ptr storage) + : holder_(std::make_shared(std::move(storage))) {} #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from @@ -84,9 +85,11 @@ class GpuBuffer { : GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {} #endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - int width() const { return current_storage().width(); } - int height() const { return current_storage().height(); } - GpuBufferFormat format() const { return current_storage().format(); } + int width() const { return holder_ ? holder_->width() : 0; } + int height() const { return holder_ ? holder_->height() : 0; } + GpuBufferFormat format() const { + return holder_ ? holder_->format() : GpuBufferFormat::kUnknown; + } // Converts to true iff valid. explicit operator bool() const { return operator!=(nullptr); } @@ -122,31 +125,17 @@ class GpuBuffer { // using views. 
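  // For example (an illustrative sketch, not part of this change), a test or
  // platform-specific caller can reach the concrete GL-backed storage with:
  //   auto tex = buffer.internal_storage<GlTextureBuffer>();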
template std::shared_ptr internal_storage() const { - for (const auto& s : storages_) - if (s->down_cast()) return std::static_pointer_cast(s); - return nullptr; + return holder_ ? holder_->internal_storage() : nullptr; } std::string DebugString() const; private: - class PlaceholderGpuBufferStorage - : public internal::GpuBufferStorageImpl { - public: - PlaceholderGpuBufferStorage(int width, int height, Format format) - : width_(width), height_(height), format_(format) {} - int width() const override { return width_; } - int height() const override { return height_; } - GpuBufferFormat format() const override { return format_; } - - private: - int width_ = 0; - int height_ = 0; - GpuBufferFormat format_ = GpuBufferFormat::kUnknown; - }; - internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, - bool for_writing) const; + bool for_writing) const { + return holder_ ? holder_->GetStorageForView(view_provider_type, for_writing) + : nullptr; + } internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, bool for_writing) const; @@ -158,25 +147,49 @@ class GpuBuffer { .template down_cast(); } - std::shared_ptr& no_storage() const { - static auto placeholder = - std::static_pointer_cast( - std::make_shared( - 0, 0, GpuBufferFormat::kUnknown)); - return placeholder; - } + // This class manages a set of alternative storages for the contents of a + // GpuBuffer. GpuBuffer was originally designed as a reference-type object, + // where a copy represents another reference to the same contents, so multiple + // GpuBuffer instances can share the same StorageHolder. + class StorageHolder { + public: + explicit StorageHolder(std::shared_ptr storage) + : StorageHolder(storage->width(), storage->height(), + storage->format()) { + storages_.push_back(std::move(storage)); + } + explicit StorageHolder(int width, int height, Format format) + : width_(width), height_(height), format_(format) {} - const internal::GpuBufferStorage& current_storage() const { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + int width() const { return width_; } + int height() const { return height_; } + GpuBufferFormat format() const { return format_; } - internal::GpuBufferStorage& current_storage() { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const; - // This is mutable because view methods that do not change the contents may - // still need to allocate new storages. - mutable std::vector> storages_; + template + std::shared_ptr internal_storage() const { + absl::MutexLock lock(&mutex_); + for (const auto& s : storages_) + if (s->down_cast()) return std::static_pointer_cast(s); + return nullptr; + } + + std::string DebugString() const; + + private: + int width_ = 0; + int height_ = 0; + GpuBufferFormat format_ = GpuBufferFormat::kUnknown; + // This is mutable because view methods that do not change the contents may + // still need to allocate new storages. 
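+    // storages_ is guarded by mutex_ (declared below) so that concurrent
+    // readers can trigger lazy storage conversions safely; see
+    // StorageHolder::GetStorageForView in gpu_buffer.cc.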
+ mutable absl::Mutex mutex_; + mutable std::vector> storages_ + ABSL_GUARDED_BY(mutex_); + }; + + std::shared_ptr holder_; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); @@ -184,15 +197,15 @@ class GpuBuffer { }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { - return storages_.empty(); + return holder_ == other; } inline bool GpuBuffer::operator==(const GpuBuffer& other) const { - return storages_ == other.storages_; + return holder_ == other.holder_; } inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { - storages_.clear(); + holder_ = other; return *this; } diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 145b71806..e4be617db 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" @@ -228,5 +229,26 @@ TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { EXPECT_TRUE(true); } +TEST_F(GpuBufferTest, CopiesShareConversions) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + GpuBuffer other_handle = buffer; + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetReadView(0); + }); + + // Check that other_handle also sees the same GlTextureBuffer as buffer. + // Note that this is deliberately written so that it still passes on platforms + // where we use another storage for GL textures (they will both be null). + // TODO: expose more accessors for testing? + EXPECT_EQ(other_handle.internal_storage(), + buffer.internal_storage()); +} + } // anonymous namespace } // namespace mediapipe From aad797197bbb4c4170cd21c6baf18084bee84446 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 07:14:46 -0800 Subject: [PATCH 107/346] TensorV1 EGL.h include fix. 
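
Replaces the vendored "third_party/GL/gl/include/EGL/..." headers with the
standard system includes, matching the removed lines in the diff below:

  #include <EGL/egl.h>
  #include <EGL/eglext.h>

This lets builds resolve EGL through the toolchain's regular include path.
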
PiperOrigin-RevId: 493596083 --- mediapipe/framework/formats/tensor.h | 5 ++--- mediapipe/framework/formats/tensor_ahwb.cc | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index ecd63c8c6..3ed72c6fd 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -39,10 +39,9 @@ #endif // MEDIAPIPE_NO_JNI #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include +#include #include - -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index b11f6b55b..90d89c40a 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -4,12 +4,13 @@ #include "mediapipe/framework/formats/tensor.h" #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include +#include + #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gl_base.h" -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB namespace mediapipe { From d9688b769f5207aff13bd782d94dd4d2ad8dcd92 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Wed, 7 Dec 2022 08:13:51 -0800 Subject: [PATCH 108/346] Hide internal APIs from mediapipe pip package's API docs. PiperOrigin-RevId: 493607984 --- .../tasks/python/core/optional_dependencies.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/python/core/optional_dependencies.py b/mediapipe/tasks/python/core/optional_dependencies.py index d4f6a6abc..b1a0ed538 100644 --- a/mediapipe/tasks/python/core/optional_dependencies.py +++ b/mediapipe/tasks/python/core/optional_dependencies.py @@ -13,6 +13,13 @@ # limitations under the License. """MediaPipe Tasks' common but optional dependencies.""" -doc_controls = lambda: None -no_op = lambda x: x -setattr(doc_controls, 'do_not_generate_docs', no_op) +# TensorFlow isn't a dependency of mediapipe pip package. It's only +# required in the API docgen pipeline so we'll ignore it if tensorflow is not +# installed. 
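+# If the import fails, the except branch below installs a pass-through
+# doc_controls.do_not_generate_docs so decorated APIs keep working without
+# TensorFlow.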
+try:
+  from tensorflow.tools.docs import doc_controls
+except ModuleNotFoundError:
+  # Replace the real doc_controls.do_not_generate_docs with a no-op
+  doc_controls = lambda: None
+  no_op = lambda x: x
+  setattr(doc_controls, 'do_not_generate_docs', no_op)

From d84eec387bb277b4f379f360d97bbf734cb3ae13 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 7 Dec 2022 10:50:12 -0800
Subject: [PATCH 109/346] Add missing import to InferenceCalculator.proto

PiperOrigin-RevId: 493649869
---
 mediapipe/calculators/tensor/inference_calculator.proto | 1 +
 mediapipe/tasks/web/BUILD | 3 ---
 mediapipe/tasks/web/rollup.config.mjs | 6 ------
 package.json | 1 -
 yarn.lock | 8 --------
 5 files changed, 1 insertion(+), 18 deletions(-)

diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto
index 46552803b..78a0039bc 100644
--- a/mediapipe/calculators/tensor/inference_calculator.proto
+++ b/mediapipe/calculators/tensor/inference_calculator.proto
@@ -17,6 +17,7 @@ syntax = "proto2";
 package mediapipe;

 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";

 option java_package = "com.google.mediapipe.calculator.proto";
 option java_outer_classname = "InferenceCalculatorProto";
diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD
index 20e717433..bc9e84147 100644
--- a/mediapipe/tasks/web/BUILD
+++ b/mediapipe/tasks/web/BUILD
@@ -44,7 +44,6 @@ rollup_bundle(
         ":audio_lib",
         "@npm//@rollup/plugin-commonjs",
         "@npm//@rollup/plugin-node-resolve",
-        "@npm//@rollup/plugin-replace",
         "@npm//@rollup/plugin-terser",
         "@npm//google-protobuf",
     ],
@@ -88,7 +87,6 @@ rollup_bundle(
         ":text_lib",
         "@npm//@rollup/plugin-commonjs",
         "@npm//@rollup/plugin-node-resolve",
-        "@npm//@rollup/plugin-replace",
         "@npm//@rollup/plugin-terser",
         "@npm//google-protobuf",
     ],
@@ -132,7 +130,6 @@ rollup_bundle(
         ":vision_lib",
         "@npm//@rollup/plugin-commonjs",
         "@npm//@rollup/plugin-node-resolve",
-        "@npm//@rollup/plugin-replace",
         "@npm//@rollup/plugin-terser",
         "@npm//google-protobuf",
     ],
diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs
index e633bf702..3b5119530 100644
--- a/mediapipe/tasks/web/rollup.config.mjs
+++ b/mediapipe/tasks/web/rollup.config.mjs
@@ -1,15 +1,9 @@
 import resolve from '@rollup/plugin-node-resolve';
 import commonjs from '@rollup/plugin-commonjs';
-import replace from '@rollup/plugin-replace';
 import terser from '@rollup/plugin-terser';

 export default {
   plugins: [
-    // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151
-    replace({
-      'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;',
-      delimiters: ['', '']
-    }),
     resolve(),
     commonjs(),
     terser()
diff --git a/package.json b/package.json
index 22a035b74..6ad0b52c0 100644
--- a/package.json
+++ b/package.json
@@ -7,7 +7,6 @@
     "@bazel/typescript": "^5.7.1",
     "@rollup/plugin-commonjs": "^23.0.2",
     "@rollup/plugin-node-resolve": "^15.0.1",
-    "@rollup/plugin-replace": "^5.0.1",
     "@rollup/plugin-terser": "^0.1.0",
     "@types/google-protobuf": "^3.15.6",
     "@types/offscreencanvas": "^2019.7.0",
diff --git a/yarn.lock b/yarn.lock
index 19c32e322..91b50456e 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -148,14 +148,6 @@
     is-module "^1.0.0"
     resolve "^1.22.1"

-"@rollup/plugin-replace@^5.0.1":
-  version "5.0.1"
-  resolved
"https://registry.yarnpkg.com/@rollup/plugin-replace/-/plugin-replace-5.0.1.tgz#49a57af3e6df111a9e75dea3f3572741f4c5c83e" - integrity sha512-Z3MfsJ4CK17BfGrZgvrcp/l6WXoKb0kokULO+zt/7bmcyayokDaQ2K3eDJcRLCTAlp5FPI4/gz9MHAsosz4Rag== - dependencies: - "@rollup/pluginutils" "^5.0.1" - magic-string "^0.26.4" - "@rollup/plugin-terser@^0.1.0": version "0.1.0" resolved "https://registry.yarnpkg.com/@rollup/plugin-terser/-/plugin-terser-0.1.0.tgz#7530c0f11667637419d71820461646c418526041" From 80c605459c2361840c1c0eab05dfa260d7dcfedc Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 11:24:15 -0800 Subject: [PATCH 110/346] Open up framework visibility. PiperOrigin-RevId: 493660013 --- mediapipe/framework/deps/BUILD | 23 ++++---------- mediapipe/framework/port/BUILD | 42 +------------------------ mediapipe/framework/profiler/BUILD | 5 ++- mediapipe/framework/tool/testdata/BUILD | 7 +++-- 4 files changed, 13 insertions(+), 64 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 95ab21707..27bc105c8 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,7 +20,9 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) bzl_library( name = "expand_template_bzl", @@ -50,13 +52,11 @@ mediapipe_proto_library( cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], ) cc_library( name = "cleanup", hdrs = ["cleanup.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -86,7 +86,6 @@ cc_library( # Use this library through "mediapipe/framework/port:gtest_main". 
visibility = [ "//mediapipe/framework/port:__pkg__", - "//third_party/visionai/algorithms/tracking:__pkg__", ], deps = [ "//mediapipe/framework/port:core_proto", @@ -108,7 +107,6 @@ cc_library( name = "file_helpers", srcs = ["file_helpers.cc"], hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":file_path", "//mediapipe/framework/port:status", @@ -134,7 +132,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:opencv_imgproc", ], @@ -151,7 +148,9 @@ cc_library( cc_library( name = "mathutil", hdrs = ["mathutil.h"], - visibility = ["//visibility:public"], + visibility = [ + "//mediapipe:__subpackages__", + ], deps = [ "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -171,7 +170,6 @@ cc_library( cc_library( name = "no_destructor", hdrs = ["no_destructor.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -190,7 +188,6 @@ cc_library( cc_library( name = "random", hdrs = ["random_base.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/port:integral_types"], ) @@ -211,14 +208,12 @@ cc_library( name = "registration_token", srcs = ["registration_token.cc"], hdrs = ["registration_token.h"], - visibility = ["//visibility:public"], ) cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:logging", @@ -279,7 +274,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], ) cc_library( @@ -310,7 +304,6 @@ cc_library( cc_library( name = "thread_options", hdrs = ["thread_options.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -356,7 +349,6 @@ cc_library( cc_test( name = "mathutil_unittest", srcs = ["mathutil_unittest.cc"], - visibility = ["//visibility:public"], deps = [ ":mathutil", "//mediapipe/framework/port:benchmark", @@ -368,7 +360,6 @@ cc_test( name = "registration_token_test", srcs = ["registration_token_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:gtest_main", @@ -381,7 +372,6 @@ cc_test( timeout = "long", srcs = ["safe_int_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":intops", "//mediapipe/framework/port:gtest_main", @@ -393,7 +383,6 @@ cc_test( name = "monotonic_clock_test", srcs = ["monotonic_clock_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":clock", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index e499ca3a6..1039dc1c6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-parse_headers"], ) @@ -28,7 +28,6 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "1", }, - visibility = ["//visibility:public"], ) #TODO : remove from OSS. 
@@ -37,13 +36,11 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "0", }, - visibility = ["//visibility:public"], ) cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:aligned_malloc_and_free", "@com_google_absl//absl/base:core_headers", @@ -57,7 +54,6 @@ cc_library( "advanced_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":advanced_proto_lite", ":core_proto", @@ -72,7 +68,6 @@ cc_library( "advanced_proto_lite_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", "//mediapipe/framework:port", @@ -83,7 +78,6 @@ cc_library( cc_library( name = "any_proto", hdrs = ["any_proto.h"], - visibility = ["//visibility:public"], deps = [ ":core_proto", ], @@ -94,7 +88,6 @@ cc_library( hdrs = [ "commandlineflags.h", ], - visibility = ["//visibility:public"], deps = [ "//third_party:glog", "@com_google_absl//absl/flags:flag", @@ -107,7 +100,6 @@ cc_library( "core_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_protobuf//:protobuf", @@ -117,7 +109,6 @@ cc_library( cc_library( name = "file_helpers", hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":status", "//mediapipe/framework/deps:file_helpers", @@ -128,7 +119,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "//mediapipe/framework/deps:image_resizer", @@ -140,14 +130,12 @@ cc_library( cc_library( name = "integral_types", hdrs = ["integral_types.h"], - visibility = ["//visibility:public"], ) cc_library( name = "benchmark", testonly = 1, hdrs = ["benchmark.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_benchmark//:benchmark", ], @@ -158,7 +146,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:re2", ], @@ -173,7 +160,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -190,7 +176,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -204,7 +189,6 @@ cc_library( hdrs = [ "logging.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//third_party:glog", @@ -217,7 +201,6 @@ cc_library( hdrs = [ "map_util.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:map_util", @@ -227,7 +210,6 @@ cc_library( cc_library( name = "numbers", hdrs = ["numbers.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:numbers"], ) @@ -238,13 +220,11 @@ config_setting( define_values = { "MEDIAPIPE_DISABLE_OPENCV": "1", }, - visibility = ["//visibility:public"], ) cc_library( name = "opencv_core", hdrs = ["opencv_core_inc.h"], - visibility = ["//visibility:public"], deps = [ "//third_party:opencv", ], @@ -253,7 +233,6 @@ cc_library( cc_library( name = "opencv_imgproc", hdrs = ["opencv_imgproc_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -263,7 +242,6 @@ cc_library( cc_library( name = "opencv_imgcodecs", hdrs = ["opencv_imgcodecs_inc.h"], - visibility = ["//visibility:public"], 
deps = [ ":opencv_core", "//third_party:opencv", @@ -273,7 +251,6 @@ cc_library( cc_library( name = "opencv_highgui", hdrs = ["opencv_highgui_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -283,7 +260,6 @@ cc_library( cc_library( name = "opencv_video", hdrs = ["opencv_video_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -294,7 +270,6 @@ cc_library( cc_library( name = "opencv_features2d", hdrs = ["opencv_features2d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -304,7 +279,6 @@ cc_library( cc_library( name = "opencv_calib3d", hdrs = ["opencv_calib3d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -314,7 +288,6 @@ cc_library( cc_library( name = "opencv_videoio", hdrs = ["opencv_videoio_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -328,7 +301,6 @@ cc_library( "parse_text_proto.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", ":logging", @@ -339,14 +311,12 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:point"], ) cc_library( name = "port", hdrs = ["port.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/base:core_headers", @@ -356,14 +326,12 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:rectangle"], ) cc_library( name = "ret_check", hdrs = ["ret_check.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:ret_check", @@ -373,7 +341,6 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:singleton"], ) @@ -382,7 +349,6 @@ cc_library( hdrs = [ "source_location.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:source_location", @@ -397,7 +363,6 @@ cc_library( "status_builder.h", "status_macros.h", ], - visibility = ["//visibility:public"], deps = [ ":source_location", "//mediapipe/framework:port", @@ -412,7 +377,6 @@ cc_library( hdrs = [ "statusor.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/status:statusor", @@ -423,7 +387,6 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], - visibility = ["//visibility:private"], deps = [ ":status", "@com_google_googletest//:gtest", @@ -433,7 +396,6 @@ cc_library( cc_library( name = "threadpool", hdrs = ["threadpool.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [":threadpool_impl_default_to_google"], "//mediapipe:android": [":threadpool_impl_default_to_mediapipe"], @@ -460,7 +422,6 @@ alias( cc_library( name = "topologicalsorter", hdrs = ["topologicalsorter.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:topologicalsorter", @@ -470,6 +431,5 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:vector"], ) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 
b53a1ac39..2947b9844 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -140,7 +140,7 @@ cc_library( name = "circular_buffer", hdrs = ["circular_buffer.h"], visibility = [ - "//visibility:public", + "//mediapipe:__subpackages__", ], deps = [ "//mediapipe/framework/port:integral_types", @@ -151,7 +151,6 @@ cc_test( name = "circular_buffer_test", size = "small", srcs = ["circular_buffer_test.cc"], - visibility = ["//visibility:public"], deps = [ ":circular_buffer", "//mediapipe/framework/port:gtest_main", @@ -164,7 +163,7 @@ cc_library( name = "trace_buffer", srcs = ["trace_buffer.h"], hdrs = ["trace_buffer.h"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework/profiler:__subpackages__"], deps = [ ":circular_buffer", "//mediapipe/framework:calculator_profile_cc_proto", diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index 906688520..f9aab7b72 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -20,7 +20,9 @@ load( licenses(["notice"]) -package(default_visibility = ["//mediapipe:__subpackages__"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) filegroup( name = "test_graph", @@ -40,7 +42,6 @@ mediapipe_simple_subgraph( testonly = 1, graph = "dub_quad_test_subgraph.pbtxt", register_as = "DubQuadTestSubgraph", - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:test_calculators", ], @@ -51,7 +52,7 @@ mediapipe_simple_subgraph( testonly = 1, graph = "nested_test_subgraph.pbtxt", register_as = "NestedTestSubgraph", - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":dub_quad_test_subgraph", "//mediapipe/framework:test_calculators", From 3c0ddf16b4c2b04cfff07d0db0aba48411468e9c Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 11:37:04 -0800 Subject: [PATCH 111/346] Fix an incorrect model sanity check in the object detector graph. PiperOrigin-RevId: 493663592 --- .../tasks/cc/vision/object_detector/object_detector_graph.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index f5dc7e061..a1625c16c 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -532,8 +532,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Checks that the model has 4 outputs. 
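+  // Only the single-subgraph requirement is kept below; the four-output
+  // restriction is dropped because it rejected otherwise valid detector
+  // models (rationale inferred from the "incorrect sanity check" commit
+  // subject, not stated in the patch itself).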
auto& model = *model_resources.GetTfLiteModel(); - if (model.subgraphs()->size() != 1 || - (*model.subgraphs())[0]->outputs()->size() != 4) { + if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected a model with a single subgraph, found %d.", From 2811e0c5c81e0ac7d39eab8c32efbe694de45940 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 12:13:25 -0800 Subject: [PATCH 112/346] Open Source the first set of MediaPipe Tasks tests for Web PiperOrigin-RevId: 493673279 --- mediapipe/framework/port/build_config.bzl | 2 + .../tasks/web/components/processors/BUILD | 76 ++++ .../processors/base_options.test.ts | 127 ++++++ .../processors/classifier_options.test.ts | 114 +++++ .../processors/classifier_result.test.ts | 80 ++++ .../processors/embedder_options.test.ts | 93 ++++ .../processors/embedder_result.test.ts | 75 ++++ mediapipe/tasks/web/components/utils/BUILD | 16 + .../utils/cosine_similarity.test.ts | 85 ++++ mediapipe/tasks/web/core/BUILD | 33 ++ mediapipe/tasks/web/core/task_runner.ts | 7 +- mediapipe/tasks/web/core/task_runner_test.ts | 107 +++++ .../tasks/web/core/task_runner_test_utils.ts | 113 +++++ package.json | 5 + tsconfig.json | 2 +- yarn.lock | 419 ++++++++++++++++-- 16 files changed, 1308 insertions(+), 46 deletions(-) create mode 100644 mediapipe/tasks/web/components/processors/base_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_result.test.ts create mode 100644 mediapipe/tasks/web/components/processors/embedder_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/embedder_result.test.ts create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.test.ts create mode 100644 mediapipe/tasks/web/core/task_runner_test.ts create mode 100644 mediapipe/tasks/web/core/task_runner_test_utils.ts diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index eaabda856..94a4a5646 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -228,6 +228,8 @@ def mediapipe_ts_library( srcs = srcs, visibility = visibility, deps = deps + [ + "@npm//@types/jasmine", + "@npm//@types/node", "@npm//@types/offscreencanvas", "@npm//@types/google-protobuf", ], diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 86e743928..148a08238 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
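+# Each *_test_lib target below compiles one spec file; its matching
+# jasmine_node_test target runs the compiled spec under Node.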
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,6 +14,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_options_test_lib", + testonly = True, + srcs = ["classifier_options.test.ts"], + deps = [ + ":classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +jasmine_node_test( + name = "classifier_options_test", + deps = [":classifier_options_test_lib"], +) + mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], @@ -22,6 +39,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_result_test_lib", + testonly = True, + srcs = ["classifier_result.test.ts"], + deps = [ + ":classifier_result", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + ], +) + +jasmine_node_test( + name = "classifier_result_test", + deps = [":classifier_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_result", srcs = ["embedder_result.ts"], @@ -31,6 +64,21 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_result_test_lib", + testonly = True, + srcs = ["embedder_result.test.ts"], + deps = [ + ":embedder_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + ], +) + +jasmine_node_test( + name = "embedder_result_test", + deps = [":embedder_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_options", srcs = ["embedder_options.ts"], @@ -40,6 +88,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_options_test_lib", + testonly = True, + srcs = ["embedder_options.test.ts"], + deps = [ + ":embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +jasmine_node_test( + name = "embedder_options_test", + deps = [":embedder_options_test_lib"], +) + mediapipe_ts_library( name = "base_options", srcs = [ @@ -53,3 +117,15 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "base_options_test_lib", + testonly = True, + srcs = ["base_options.test.ts"], + deps = [":base_options"], +) + +jasmine_node_test( + name = "base_options_test", + deps = [":base_options_test_lib"], +) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts new file mode 100644 index 000000000..46c2277e9 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -0,0 +1,127 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +// Placeholder for internal dependency on trusted resource URL builder + +import {convertBaseOptionsToProto} from './base_options'; + +describe('convertBaseOptionsToProto()', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + }); + + it('verifies that at least one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({})) + .toBeRejectedWithError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({ + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + })) + .toBeRejectedWithError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('downloads model', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetPath: `foo`, + }); + + expect(fetchSpy).toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable CPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'cpu', + }); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + expect(baseOptionsProto.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + let baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + // Clear backend + baseOptionsProto = + await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_options.test.ts b/mediapipe/tasks/web/components/processors/classifier_options.test.ts new file mode 100644 index 000000000..928bda426 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_options.test.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +import {convertClassifierOptionsToProto} from './classifier_options'; + +interface TestCase { + optionName: keyof ClassifierOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertClassifierOptionsToProto()', () => { + function verifyOption( + actualClassifierOptions: ClassifierOptionsProto, + expectedClassifierOptions: Record = {}): void { + expect(actualClassifierOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedClassifierOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, + { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + + classifierOptionsProto = + convertClassifierOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + classifierOptionsProto, + {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {maxResults: 2}, classifierOptionsProto); + verifyOption(classifierOptionsProto, {'maxResults': 2}); + }); + + it('merges options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {displayNamesLocale: 'en'}, classifierOptionsProto); + verifyOption( + classifierOptionsProto, {'maxResults': 1, 'displayNamesLocale': 'en'}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts 
b/mediapipe/tasks/web/components/processors/classifier_result.test.ts new file mode 100644 index 000000000..4b93d0a76 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts @@ -0,0 +1,80 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; + +import {convertFromClassificationResultProto} from './classifier_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromClassificationResultProto()', () => { + it('transforms custom values', () => { + const classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + clasification.setIndex(2); + clasification.setScore(0.3); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 2, + score: 0.3, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 1 + }); + }); + + it('transforms default values', () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{index: 0, score: 0, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + }); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_options.test.ts b/mediapipe/tasks/web/components/processors/embedder_options.test.ts new file mode 100644 index 000000000..b879a6b29 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_options.test.ts @@ -0,0 +1,93 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +import {convertEmbedderOptionsToProto} from './embedder_options'; + +interface TestCase { + optionName: keyof EmbedderOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertEmbedderOptionsToProto()', () => { + function verifyOption( + actualEmbedderOptions: EmbedderOptionsProto, + expectedEmbedderOptions: Record = {}): void { + expect(actualEmbedderOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedEmbedderOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'l2Normalize', + protoName: 'l2Normalize', + customValue: true, + defaultValue: undefined + }, + { + optionName: 'quantize', + protoName: 'quantize', + customValue: true, + defaultValue: undefined + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + + embedderOptionsProto = + convertEmbedderOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let embedderOptionsProto = + convertEmbedderOptionsToProto({l2Normalize: true}); + verifyOption(embedderOptionsProto, {'l2Normalize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: false}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': false}); + }); + + it('replaces options', () => { + let embedderOptionsProto = convertEmbedderOptionsToProto({quantize: true}); + verifyOption(embedderOptionsProto, {'quantize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: true}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': true, 'quantize': true}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_result.test.ts b/mediapipe/tasks/web/components/processors/embedder_result.test.ts new file mode 100644 index 000000000..97ba935c8 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_result.test.ts @@ -0,0 +1,75 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; + +import {convertFromEmbeddingResultProto} from './embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromEmbeddingResultProto()', () => { + it('transforms custom values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + expect(timestampMs).toEqual(1); + }); + + it('transforms custom quantized values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + expect(timestampMs).toEqual(1); + }); +}); diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD index 1c1ba69ca..f4a215e48 100644 --- a/mediapipe/tasks/web/components/utils/BUILD +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -1,4 +1,5 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -9,3 +10,18 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", ], ) + +mediapipe_ts_library( + name = "cosine_similarity_test_lib", + testonly = True, + srcs = ["cosine_similarity.test.ts"], + deps = [ + ":cosine_similarity", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +jasmine_node_test( + name = "cosine_similarity_test", + deps = [":cosine_similarity_test_lib"], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts new file mode 100644 index 000000000..f442caa20 --- /dev/null 
+++ b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +import {computeCosineSimilarity} from './cosine_similarity'; + +describe('computeCosineSimilarity', () => { + it('fails with quantized and float embeddings', () => { + const u: Embedding = {floatEmbedding: [1.0], headIndex: 0, headName: ''}; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([1.0]), + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between quantized and float embeddings/); + }); + + it('fails with zero norm', () => { + const u = {floatEmbedding: [0.0], headIndex: 0, headName: ''}; + expect(() => computeCosineSimilarity(u, u)) + .toThrowError( + /Cannot compute cosine similarity on embedding with 0 norm/); + }); + + it('fails with different sizes', () => { + const u: + Embedding = {floatEmbedding: [1.0, 2.0], headIndex: 0, headName: ''}; + const v: Embedding = { + floatEmbedding: [1.0, 2.0, 3.0], + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between embeddings of different sizes/); + }); + + it('succeeds with float embeddings', () => { + const u: Embedding = { + floatEmbedding: [1.0, 0.0, 0.0, 0.0], + headIndex: 0, + headName: '' + }; + const v: Embedding = { + floatEmbedding: [0.5, 0.5, 0.5, 0.5], + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(0.5); + }); + + it('succeeds with quantized embeddings', () => { + const u: Embedding = { + quantizedEmbedding: new Uint8Array([255, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([0, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(-1.0); + }); +}); diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index be1b71f5d..1721661f5 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,6 +1,7 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
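+# task_runner_test_utils (added below) provides a fake WasmModule and graph
+# verification helpers shared by the task-specific Web tests.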
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -32,6 +33,38 @@ mediapipe_ts_library( deps = [":core"], ) +mediapipe_ts_library( + name = "task_runner_test_utils", + testonly = True, + srcs = [ + "task_runner_test_utils.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + ], +) + +mediapipe_ts_library( + name = "task_runner_test_lib", + testonly = True, + srcs = [ + "task_runner_test.ts", + ], + deps = [ + ":task_runner", + ":task_runner_test_utils", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "task_runner_test", + deps = [":task_runner_test_lib"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 71e159dce..6712c4d89 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -77,9 +77,10 @@ export abstract class TaskRunner { } constructor( - wasmModule: WasmModule, - glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); + wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + graphRunner?: GraphRunnerImageLib) { + this.graphRunner = + graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts new file mode 100644 index 000000000..c9aad9d25 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -0,0 +1,107 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {TaskRunner} from '../../../tasks/web/core/task_runner'; +import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; +import {ErrorListener} from '../../../web/graph_runner/graph_runner'; + +import {GraphRunnerImageLib} from './task_runner'; + +class TaskRunnerFake extends TaskRunner { + protected baseOptions = new BaseOptionsProto(); + private errorListener: ErrorListener|undefined; + private errors: string[] = []; + + static createFake(): TaskRunnerFake { + const wasmModule = createSpyWasmModule(); + return new TaskRunnerFake(wasmModule); + } + + constructor(wasmModuleFake: SpyWasmModule) { + super( + wasmModuleFake, /* glCanvas= */ null, + jasmine.createSpyObj([ + 'setAutoRenderToScreen', 'setGraph', 'finishProcessing', + 'registerModelResourcesGraphService', 'attachErrorListener' + ])); + const graphRunner = this.graphRunner as jasmine.SpyObj; + expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled(); + expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); + graphRunner.attachErrorListener.and.callFake(listener => { + this.errorListener = listener; + }); + graphRunner.setGraph.and.callFake(() => { + this.throwErrors(); + }); + graphRunner.finishProcessing.and.callFake(() => { + this.throwErrors(); + }); + } + + enqueueError(message: string): void { + this.errors.push(message); + } + + override finishProcessing(): void { + super.finishProcessing(); + } + + override setGraph(graphData: Uint8Array, isBinary: boolean): void { + super.setGraph(graphData, isBinary); + } + + private throwErrors(): void { + expect(this.errorListener).toBeDefined(); + for (const error of this.errors) { + this.errorListener!(/* errorCode= */ -1, error); + } + this.errors = []; + } +} + +describe('TaskRunner', () => { + it('handles errors during graph update', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError('Test error'); + }); + + it('handles errors during graph execution', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.finishProcessing(); + }).toThrowError('Test error'); + }); + + it('can handle multiple errors', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.enqueueError('Test error 1'); + taskRunner.enqueueError('Test error 2'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error 1, Test error 2/); + }); +}); diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts new file mode 100644 index 000000000..2a1161a55 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -0,0 +1,113 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; +import {WasmModule} from '../../../web/graph_runner/graph_runner'; +import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service'; + +type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources; + +/** + * Convenience type for our fake WasmModule for Jasmine testing. + */ +export declare type SpyWasmModule = jasmine.SpyObj; + +/** + * Factory function for creating a fake WasmModule for our Jasmine tests, + * allowing our APIs to no longer rely on the Wasm layer so they can run tests + * in pure JS/TS (and optionally spy on the calls). + */ +export function createSpyWasmModule(): SpyWasmModule { + return jasmine.createSpyObj([ + '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', + '_attachProtoVectorListener', '_free', '_waitUntilIdle', + '_addStringToInputStream', '_registerModelResourcesGraphService', + '_configureAudio' + ]); +} + +/** + * Sets up our equality testing to use a custom float equality checking function + * to avoid incorrect test results due to minor floating point inaccuracies. + */ +export function addJasmineCustomFloatEqualityTester() { + jasmine.addCustomEqualityTester((a, b) => { // Custom float equality + if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { + return Math.abs(a - b) < 5e-8; + } + return; + }); +} + +/** The minimum interface provided by a test fake. */ +export interface MediapipeTasksFake { + graph: CalculatorGraphConfig|undefined; + calculatorName: string; + attachListenerSpies: jasmine.Spy[]; +} + +/** An map of field paths to values */ +export type FieldPathToValue = [string[] | string, unknown]; + +/** + * Verifies that the graph has been initialized and that it contains the + * provided options. + */ +export function verifyGraph( + tasksFake: MediapipeTasksFake, + expectedCalculatorOptions?: FieldPathToValue, + expectedBaseOptions?: FieldPathToValue, + ): void { + expect(tasksFake.graph).toBeDefined(); + expect(tasksFake.graph!.getNodeList().length).toBe(1); + const node = tasksFake.graph!.getNodeList()[0].toObject(); + expect(node).toEqual( + jasmine.objectContaining({calculator: tasksFake.calculatorName})); + + if (expectedBaseOptions) { + const [fieldPath, value] = expectedBaseOptions; + let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record)[fieldName]; + } + expect(proto).toEqual(value); + } + + if (expectedCalculatorOptions) { + const [fieldPath, value] = expectedCalculatorOptions; + let proto = (node.options as {ext: unknown}).ext; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record)[fieldName]; + } + expect(proto).toEqual(value); + } +} + +/** + * Verifies all listeners (as exposed by `.attachListenerSpies`) have been + * attached at least once since the last call to `verifyListenersRegistered()`. 
+ * This helps us to ensure that listeners are re-registered with every graph + * update. + */ +export function verifyListenersRegistered(tasksFake: MediapipeTasksFake): void { + for (const spy of tasksFake.attachListenerSpies) { + expect(spy.calls.count()).toBeGreaterThanOrEqual(1); + spy.calls.reset(); + } +} diff --git a/package.json b/package.json index 6ad0b52c0..89b62bc83 100644 --- a/package.json +++ b/package.json @@ -3,14 +3,19 @@ "version": "0.0.0-alphga", "description": "MediaPipe GitHub repo", "devDependencies": { + "@bazel/jasmine": "^5.7.2", "@bazel/rollup": "^5.7.1", "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", + "@types/jasmine": "^4.3.1", + "@types/node": "^18.11.11", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", + "jasmine": "^4.5.0", + "jasmine-core": "^4.5.0", "protobufjs": "^7.1.2", "protobufjs-cli": "^1.0.2", "rollup": "^2.3.0", diff --git a/tsconfig.json b/tsconfig.json index c17b1902e..970246dbb 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -10,7 +10,7 @@ "inlineSourceMap": true, "inlineSources": true, "strict": true, - "types": ["@types/offscreencanvas"], + "types": ["@types/offscreencanvas", "@types/jasmine", "node"], "rootDirs": [ ".", "./bazel-out/host/bin", diff --git a/yarn.lock b/yarn.lock index 91b50456e..9c4d91d30 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3,34 +3,52 @@ "@babel/parser@^7.9.4": - version "7.20.3" - resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.3.tgz#5358cf62e380cf69efcb87a7bb922ff88bfac6e2" - integrity sha512-OP/s5a94frIPXwjzEcv5S/tpQfc6XhxYUnmWpgdqMWGgYCuErA3SzozaRAMQgSZWKeTJxht9aWAkUY+0UzvOFg== + version "7.20.5" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.5.tgz#7f3c7335fe417665d929f34ae5dceae4c04015e8" + integrity sha512-r27t/cy/m9uKLXQNWWebeCUHgnAZq0CpG1OwKRxzJMP1vpSU4bSIK2hq+/cp0bQxetkXx38n09rNu8jVkcK/zA== + +"@bazel/jasmine@^5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/jasmine/-/jasmine-5.7.2.tgz#438f272e66e939106cbdd58db709cd6aa008131b" + integrity sha512-RJruOB6S9e0efTNIE2JVdaslguUXh5KcmLUCq/xLCt0zENP44ssp9OooDIrZ8H+Sp4mLDNBX7CMMA9WTsbsxTQ== + dependencies: + c8 "~7.5.0" + jasmine-reporters "~2.5.0" "@bazel/rollup@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.1.tgz#6f644c2d493a5bd9cd3724a6f239e609585c6e37" - integrity sha512-LLNogoK2Qx9GIJVywQ+V/czjud8236mnaRX//g7qbOyXoWZDQvAEgsxRHq+lS/XX9USbh+zJJlfb+Dfp/PXx4A== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.2.tgz#9953b06e3de52794791cee4f89540c263b035fcf" + integrity sha512-yGWLheSKdMnJ/Y3/qg+zCDx/qkD04FBFp+BjRS8xP4yvlz9G4rW3zc45VzHHz3oOywgQaY1vhfKuZMCcjTGEyA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" "@bazel/typescript@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682" - integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.2.tgz#a341215dc93ce28794e8430b311756816140bd78" + integrity sha512-tarBJBEIirnq/YaeYu18vXcDxjzlq4xhCXvXUxA0lhHX5oArjEcAEn4tmO0jF+t/7cbkAdMT7daG6vIHSz0QAA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" semver "5.6.0" source-map-support "0.5.9" tsutils 
"3.21.0" -"@bazel/worker@5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad" - integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg== +"@bazel/worker@5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.2.tgz#43d800dc1b5a3707340a4eb0102da81c53fc6f63" + integrity sha512-H+auDA0QKF4mtZxKkZ2OKJvD7hGXVsVKtvcf4lbb93ur0ldpb5k810PcDxngmIGBcIX5kmyxniNTIiGFNobWTg== dependencies: google-protobuf "^3.6.1" +"@bcoe/v8-coverage@^0.2.3": + version "0.2.3" + resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" + integrity sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw== + +"@istanbuljs/schema@^0.1.2": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@istanbuljs/schema/-/schema-0.1.3.tgz#e45e384e4b8ec16bce2fd903af78450f6bf7ec98" + integrity sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA== + "@jridgewell/gen-mapping@^0.3.0": version "0.3.2" resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9" @@ -125,9 +143,9 @@ integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== "@rollup/plugin-commonjs@^23.0.2": - version "23.0.2" - resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.2.tgz#3a3a5b7b1b1cb29037eb4992edcaae997d7ebd92" - integrity sha512-e9ThuiRf93YlVxc4qNIurvv+Hp9dnD+4PjOqQs5vAYfcZ3+AXSrcdzXnVjWxcGQOa6KGJFcRZyUI3ktWLavFjg== + version "23.0.3" + resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.3.tgz#442cd8ccca1b7563a503da86fc84a1a7112b54bb" + integrity sha512-31HxrT5emGfTyIfAs1lDQHj6EfYxTXcwtX5pIIhq+B/xZBNIqQ179d/CkYxlpYmFCxT78AeU4M8aL8Iv/IBxFA== dependencies: "@rollup/pluginutils" "^5.0.1" commondir "^1.0.1" @@ -174,6 +192,21 @@ resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw== +"@types/is-windows@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@types/is-windows/-/is-windows-1.0.0.tgz#1011fa129d87091e2f6faf9042d6704cdf2e7be0" + integrity sha512-tJ1rq04tGKuIJoWIH0Gyuwv4RQ3+tIu7wQrC0MV47raQ44kIzXSSFKfrxFUOWVRvesoF7mrTqigXmqoZJsXwTg== + +"@types/istanbul-lib-coverage@^2.0.1": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.4.tgz#8467d4b3c087805d63580480890791277ce35c44" + integrity sha512-z/QT1XN4K4KYuslS23k62yDIDLwLFkzxOuMplDtObz0+y7VqJCaO2o+SPwHCvLFZh7xazvvoor2tA/hPz9ee7g== + +"@types/jasmine@^4.3.1": + version "4.3.1" + resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-4.3.1.tgz#2d8ab5601c2fe7d9673dcb157e03f128ab5c5fff" + integrity sha512-Vu8l+UGcshYmV1VWwULgnV/2RDbBaO6i2Ptx7nd//oJPIZGhoI1YLST4VKagD2Pq/Bc2/7zvtvhM7F3p4SN7kQ== + "@types/linkify-it@*": version "3.0.2" resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9" @@ -192,10 +225,10 @@ resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" integrity 
sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== -"@types/node@>=13.7.0": - version "18.11.9" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" - integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg== +"@types/node@>=13.7.0", "@types/node@^18.11.11": + version "18.11.11" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.11.tgz#1d455ac0211549a8409d3cdb371cd55cc971e8dc" + integrity sha512-KJ021B1nlQUBLopzZmPBVuGU9un7WJd/W4ya7Ih02B4Uwky5Nja0yGYav2EfYIk0RR2Q9oVhf60S2XR1BCWJ2g== "@types/offscreencanvas@^2019.7.0": version "2019.7.0" @@ -207,6 +240,11 @@ resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-1.20.2.tgz#97d26e00cd4a0423b4af620abecf3e6f442b7975" integrity sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q== +"@xmldom/xmldom@^0.8.5": + version "0.8.6" + resolved "https://registry.yarnpkg.com/@xmldom/xmldom/-/xmldom-0.8.6.tgz#8a1524eb5bd5e965c1e3735476f0262469f71440" + integrity sha512-uRjjusqpoqfmRkTaNuLJ2VohVr67Q5YwDATW3VU7PfzTj6IRaihGrYI7zckGZjxQPBIp63nfvJbM+Yu5ICh0Bg== + acorn-jsx@^5.3.2: version "5.3.2" resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" @@ -217,7 +255,12 @@ acorn@^8.5.0, acorn@^8.8.0: resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== -ansi-styles@^4.1.0: +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.3.0" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== @@ -264,6 +307,25 @@ builtin-modules@^3.3.0: resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-3.3.0.tgz#cae62812b89801e9656336e46223e030386be7b6" integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw== +c8@~7.5.0: + version "7.5.0" + resolved "https://registry.yarnpkg.com/c8/-/c8-7.5.0.tgz#a69439ab82848f344a74bb25dc5dd4e867764481" + integrity sha512-GSkLsbvDr+FIwjNSJ8OwzWAyuznEYGTAd1pzb/Kr0FMLuV4vqYJTyjboDTwmlUNAG6jAU3PFWzqIdKrOt1D8tw== + dependencies: + "@bcoe/v8-coverage" "^0.2.3" + "@istanbuljs/schema" "^0.1.2" + find-up "^5.0.0" + foreground-child "^2.0.0" + furi "^2.0.0" + istanbul-lib-coverage "^3.0.0" + istanbul-lib-report "^3.0.0" + istanbul-reports "^3.0.2" + rimraf "^3.0.0" + test-exclude "^6.0.0" + v8-to-istanbul "^7.1.0" + yargs "^16.0.0" + yargs-parser "^20.0.0" + catharsis@^0.9.0: version "0.9.0" resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" @@ -279,6 +341,15 @@ chalk@^4.0.0: ansi-styles "^4.1.0" supports-color "^7.1.0" +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + 
string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + color-convert@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" @@ -306,6 +377,20 @@ concat-map@0.0.1: resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== +convert-source-map@^1.6.0: + version "1.9.0" + resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.9.0.tgz#7faae62353fb4213366d0ca98358d22e8368b05f" + integrity sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A== + +cross-spawn@^7.0.0: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" + integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + deep-is@~0.1.3: version "0.1.4" resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" @@ -316,11 +401,21 @@ deepmerge@^4.2.2: resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955" integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg== +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + entities@~2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w== +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + escape-string-regexp@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" @@ -382,6 +477,22 @@ fast-levenshtein@~2.0.6: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== +find-up@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-5.0.0.tgz#4c92819ecb7083561e4f4a240a86be5198f536fc" + integrity sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng== + dependencies: + locate-path "^6.0.0" + path-exists "^4.0.0" + +foreground-child@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/foreground-child/-/foreground-child-2.0.0.tgz#71b32800c9f15aa8f2f83f4a6bd9bff35d861a53" + integrity sha512-dCIq9FpEcyQyXKCkyzmlPTFNgrCzPudOe+mhvJU5zAtlBnGVy2yKxtfsxK2tQBThwq225jcvBjpw1Gr40uzZCA== + dependencies: + cross-spawn "^7.0.0" + signal-exit "^3.0.2" + fs.realpath@^1.0.0: version "1.0.0" resolved 
"https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" @@ -397,7 +508,20 @@ function-bind@^1.1.1: resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== -glob@^7.1.3: +furi@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/furi/-/furi-2.0.0.tgz#13d85826a1af21acc691da6254b3888fc39f0b4a" + integrity sha512-uKuNsaU0WVaK/vmvj23wW1bicOFfyqSsAIH71bRZx8kA4Xj+YCHin7CJKJJjkIsmxYaPFLk9ljmjEyB7xF7WvQ== + dependencies: + "@types/is-windows" "^1.0.0" + is-windows "^1.0.2" + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob@^7.1.3, glob@^7.1.4, glob@^7.1.6: version "7.2.3" resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== @@ -442,6 +566,11 @@ has@^1.0.3: dependencies: function-bind "^1.1.1" +html-escaper@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" + integrity sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== + inflight@^1.0.4: version "1.0.6" resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" @@ -469,6 +598,11 @@ is-core-module@^2.9.0: dependencies: has "^1.0.3" +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + is-module@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" @@ -481,6 +615,59 @@ is-reference@1.2.1: dependencies: "@types/estree" "*" +is-windows@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw== + +istanbul-lib-coverage@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.0.tgz#189e7909d0a39fa5a3dfad5b03f71947770191d3" + integrity sha512-eOeJ5BHCmHYvQK7xt9GkdHuzuCGS1Y6g9Gvnx3Ym33fz/HpLRYxiS0wHNr+m/MBC8B647Xt608vCDEvhl9c6Mw== + +istanbul-lib-report@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz#7518fe52ea44de372f460a76b5ecda9ffb73d8a6" + integrity sha512-wcdi+uAKzfiGT2abPpKZ0hSU1rGQjUQnLvtY5MpQ7QCTahD3VODhcu4wcfY1YtkGaDD5yuydOLINXsfbus9ROw== + dependencies: + istanbul-lib-coverage "^3.0.0" + make-dir "^3.0.0" + supports-color "^7.1.0" + 
+istanbul-reports@^3.0.2: + version "3.1.5" + resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.1.5.tgz#cc9a6ab25cb25659810e4785ed9d9fb742578bae" + integrity sha512-nUsEMa9pBt/NOHqbcbeJEgqIlY/K7rVWUX6Lql2orY5e9roQOthbR3vtY4zzf2orPELg80fnxxk9zUyPlgwD1w== + dependencies: + html-escaper "^2.0.0" + istanbul-lib-report "^3.0.0" + +jasmine-core@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine-core/-/jasmine-core-4.5.0.tgz#1a6bd0bde3f60996164311c88a0995d67ceda7c3" + integrity sha512-9PMzyvhtocxb3aXJVOPqBDswdgyAeSB81QnLop4npOpbqnheaTEwPc9ZloQeVswugPManznQBjD8kWDTjlnHuw== + +jasmine-reporters@~2.5.0: + version "2.5.2" + resolved "https://registry.yarnpkg.com/jasmine-reporters/-/jasmine-reporters-2.5.2.tgz#b5dfa1d9c40b8020c5225e0e1e2b9953d66a4d69" + integrity sha512-qdewRUuFOSiWhiyWZX8Yx3YNQ9JG51ntBEO4ekLQRpktxFTwUHy24a86zD/Oi2BRTKksEdfWQZcQFqzjqIkPig== + dependencies: + "@xmldom/xmldom" "^0.8.5" + mkdirp "^1.0.4" + +jasmine@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine/-/jasmine-4.5.0.tgz#8d3c0d0a33a61e4d05c9f9747ee5dfaf6f7b5d3d" + integrity sha512-9olGRvNZyADIwYL9XBNBst5BTU/YaePzuddK+YRslc7rI9MdTIE4r3xaBKbv2GEmzYYUfMOdTR8/i6JfLZaxSQ== + dependencies: + glob "^7.1.6" + jasmine-core "^4.5.0" + js2xmlparser@^4.0.2: version "4.0.2" resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" @@ -531,7 +718,14 @@ linkify-it@^3.0.1: dependencies: uc.micro "^1.0.1" -lodash@^4.17.14, lodash@^4.17.15: +locate-path@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-6.0.0.tgz#55321eb309febbc59c4801d931a72452a681d286" + integrity sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw== + dependencies: + p-locate "^5.0.0" + +lodash@^4.17.15, lodash@^4.17.21: version "4.17.21" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -555,6 +749,13 @@ magic-string@^0.26.4: dependencies: sourcemap-codec "^1.4.8" +make-dir@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-3.1.0.tgz#415e967046b3a7f1d185277d84aa58203726a13f" + integrity sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw== + dependencies: + semver "^6.0.0" + markdown-it-anchor@^8.4.1: version "8.6.5" resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" @@ -572,16 +773,16 @@ markdown-it@^12.3.2: uc.micro "^1.0.5" marked@^4.0.10: - version "4.2.2" - resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.2.tgz#1d2075ad6cdfe42e651ac221c32d949a26c0672a" - integrity sha512-JjBTFTAvuTgANXx82a5vzK9JLSMoV6V3LBVn4Uhdso6t7vXrGx7g1Cd2r6NYSsxrYbQGFCMqBDhFHyK5q2UvcQ== + version "4.2.3" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.3.tgz#bd76a5eb510ff1d8421bc6c3b2f0b93488c15bea" + integrity sha512-slWRdJkbTZ+PjkyJnE30Uid64eHwbwa1Q25INCAYfZlK4o6ylagBy/Le9eWntqJFoFT93ikUKMv47GZ4gTwHkw== mdurl@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e" integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g== -minimatch@^3.1.1: +minimatch@^3.0.4, minimatch@^3.1.1: version "3.1.2" resolved 
"https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== @@ -589,9 +790,9 @@ minimatch@^3.1.1: brace-expansion "^1.1.7" minimatch@^5.0.1: - version "5.1.0" - resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7" - integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg== + version "5.1.1" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.1.tgz#6c9dffcf9927ff2a31e74b5af11adf8b9604b022" + integrity sha512-362NP+zlprccbEt/SkxKfRMHnNY85V74mVnpUpNyr3F35covl09Kec7/sEFLt3RA4oXmewtoaanoIf67SE5Y5g== dependencies: brace-expansion "^2.0.1" @@ -624,11 +825,35 @@ optionator@^0.8.1: type-check "~0.3.2" word-wrap "~1.2.3" +p-limit@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-3.1.0.tgz#e1daccbe78d0d1388ca18c64fea38e3e57e3706b" + integrity sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ== + dependencies: + yocto-queue "^0.1.0" + +p-locate@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-5.0.0.tgz#83c8315c6785005e3bd021839411c9e110e6d834" + integrity sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw== + dependencies: + p-limit "^3.0.2" + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + path-is-absolute@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== +path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + path-parse@^1.0.7: version "1.0.7" resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" @@ -678,12 +903,17 @@ protobufjs@^7.1.2: "@types/node" ">=13.7.0" long "^5.0.0" +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + requizzle@^0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded" - integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ== + version "0.2.4" + resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.4.tgz#319eb658b28c370f0c20f968fa8ceab98c13d27c" + integrity sha512-JRrFk1D4OQ4SqovXOgdav+K8EAhSB/LJZqCz8tbX0KObcdeM15Ss59ozWMBWmmINMagCwmqn4ZNryUGpBsl6Jw== dependencies: - lodash "^4.17.14" + lodash "^4.17.21" resolve@^1.22.1: version "1.22.1" @@ -713,6 +943,11 @@ semver@5.6.0: resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" integrity 
sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg== +semver@^6.0.0: + version "6.3.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d" + integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw== + semver@^7.1.2: version "7.3.8" resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798" @@ -720,6 +955,23 @@ semver@^7.1.2: dependencies: lru-cache "^6.0.0" +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +signal-exit@^3.0.2: + version "3.0.7" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.7.tgz#a9a1767f8af84155114eaabd73f99273c8f59ad9" + integrity sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ== + source-map-support@0.5.9: version "0.5.9" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f" @@ -741,11 +993,32 @@ source-map@^0.6.0, source-map@~0.6.1: resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +source-map@^0.7.3: + version "0.7.4" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.4.tgz#a9bbe705c9d8846f4e08ff6765acf0f1b0898656" + integrity sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA== + sourcemap-codec@^1.4.8: version "1.4.8" resolved "https://registry.yarnpkg.com/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz#ea804bd94857402e6992d05a38ef1ae35a9ab4c4" integrity sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA== +string-width@^4.1.0, string-width@^4.2.0: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + strip-json-comments@^3.1.0: version "3.1.1" resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" @@ -769,15 +1042,24 @@ taffydb@2.6.2: integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== terser@^5.15.1: - version "5.15.1" - resolved 
"https://registry.yarnpkg.com/terser/-/terser-5.15.1.tgz#8561af6e0fd6d839669c73b92bdd5777d870ed6c" - integrity sha512-K1faMUvpm/FBxjBXud0LWVAGxmvoPbZbfTCYbSgaaYQaIXI3/TdI7a7ZGA73Zrou6Q8Zmz3oeUTsp/dj+ag2Xw== + version "5.16.1" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.16.1.tgz#5af3bc3d0f24241c7fb2024199d5c461a1075880" + integrity sha512-xvQfyfA1ayT0qdK47zskQgRZeWLoOQ8JQ6mIgRGVNwZKdQMU+5FkCBjmv4QjcrTzyZquRw2FVtlJSRUmMKQslw== dependencies: "@jridgewell/source-map" "^0.3.2" acorn "^8.5.0" commander "^2.20.0" source-map-support "~0.5.20" +test-exclude@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/test-exclude/-/test-exclude-6.0.0.tgz#04a8698661d805ea6fa293b6cb9e63ac044ef15e" + integrity sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w== + dependencies: + "@istanbuljs/schema" "^0.1.2" + glob "^7.1.4" + minimatch "^3.0.4" + tmp@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" @@ -812,9 +1094,9 @@ type-check@~0.3.2: prelude-ls "~1.1.2" typescript@^4.8.4: - version "4.8.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6" - integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ== + version "4.9.3" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.3.tgz#3aea307c1746b8c384435d8ac36b8a2e580d85db" + integrity sha512-CIfGzTelbKNEnLpLdGFgdyKhG23CKdKgQPOBc+OUNrkJ2vr+KSzsSV5kq5iWhEQbok+quxgGzrAtGWCyU7tHnA== uc.micro@^1.0.1, uc.micro@^1.0.5: version "1.0.6" @@ -831,11 +1113,36 @@ underscore@~1.13.2: resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441" integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A== +v8-to-istanbul@^7.1.0: + version "7.1.2" + resolved "https://registry.yarnpkg.com/v8-to-istanbul/-/v8-to-istanbul-7.1.2.tgz#30898d1a7fa0c84d225a2c1434fb958f290883c1" + integrity sha512-TxNb7YEUwkLXCQYeudi6lgQ/SZrzNO4kMdlqVxaZPUIUjCv6iSSypUQX70kNBSERpQ8fk48+d61FXk+tgqcWow== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.1" + convert-source-map "^1.6.0" + source-map "^0.7.3" + +which@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + word-wrap@~1.2.3: version "1.2.3" resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + wrappy@1: version "1.0.2" resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" @@ -846,7 +1153,35 @@ xmlcreate@^2.0.4: resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be" integrity 
sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg== +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + yallist@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== + +yargs-parser@^20.0.0, yargs-parser@^20.2.2: + version "20.2.9" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.9.tgz#2eb7dc3b0289718fc295f362753845c41a0c94ee" + integrity sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w== + +yargs@^16.0.0: + version "16.2.0" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" + integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.0" + y18n "^5.0.5" + yargs-parser "^20.2.2" + +yocto-queue@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" + integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q== From 955f090f9f0e69300c3bc331c52a426b0dec5dab Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 13:06:57 -0800 Subject: [PATCH 113/346] Retire the visibility group "//mediapipe/framework:mediapipe_internal". 
PiperOrigin-RevId: 493687025 --- mediapipe/framework/profiler/BUILD | 4 +--- mediapipe/framework/tool/BUILD | 19 +++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 2947b9844..3b6976fc8 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -291,9 +291,7 @@ cc_library( "-ObjC++", ], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/flags:flag", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 52d04b4b1..453b5a0e8 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -90,7 +90,7 @@ mediapipe_proto_library( name = "packet_generator_wrapper_calculator_proto", srcs = ["packet_generator_wrapper_calculator.proto"], def_py_proto = False, - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:packet_generator_proto", @@ -120,13 +120,13 @@ cc_library( name = "fill_packet_set", srcs = ["fill_packet_set.cc"], hdrs = ["fill_packet_set.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ + ":status_util", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/memory", ], ) @@ -162,7 +162,6 @@ cc_library( cc_test( name = "executor_util_test", srcs = ["executor_util_test.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":executor_util", "//mediapipe/framework/port:gtest_main", @@ -173,7 +172,7 @@ cc_test( cc_library( name = "options_map", hdrs = ["options_map.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe:__subpackages__"], deps = [ ":type_util", "//mediapipe/framework:calculator_cc_proto", @@ -193,7 +192,7 @@ cc_library( name = "options_field_util", srcs = ["options_field_util.cc"], hdrs = ["options_field_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":field_data_cc_proto", ":name_util", @@ -216,7 +215,7 @@ cc_library( name = "options_syntax_util", srcs = ["options_syntax_util.cc"], hdrs = ["options_syntax_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":name_util", ":options_field_util", @@ -235,8 +234,9 @@ cc_library( name = "options_util", srcs = ["options_util.cc"], hdrs = ["options_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ + ":name_util", ":options_field_util", ":options_map", ":options_registry", @@ -254,7 +254,6 @@ cc_library( "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:name_util", "@com_google_absl//absl/strings", ], ) @@ -323,7 +322,7 @@ mediapipe_cc_test( cc_library( name = "packet_generator_wrapper_calculator", srcs = ["packet_generator_wrapper_calculator.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = 
["//mediapipe/framework:__subpackages__"], deps = [ ":packet_generator_wrapper_calculator_cc_proto", "//mediapipe/framework:calculator_base", From ea74db86dd2926278c9d2486bf58e23caa1a97a6 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 14:04:37 -0800 Subject: [PATCH 114/346] Tensor: clang tidy fixes. PiperOrigin-RevId: 493703073 --- mediapipe/framework/formats/tensor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index c31eba350..9e1406dbb 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -551,7 +551,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { }); } else #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - + { // Transfer data from texture if not transferred from SSBO/MTLBuffer // yet. if (valid_ & kValidOpenGlTexture2d) { @@ -582,6 +582,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { } }); } + } #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 valid_ |= kValidCpu; } From 7faee517c4606e647ae63ae4296fae54d08f6abb Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 14:31:02 -0800 Subject: [PATCH 115/346] Tensor: Move general CPU/SSBO tensor storage into Ahwb-backed CPU/SSBO storage. PiperOrigin-RevId: 493710495 --- mediapipe/framework/formats/tensor.h | 1 + mediapipe/framework/formats/tensor_ahwb.cc | 40 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 3ed72c6fd..151aa299d 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -418,6 +418,7 @@ class Tensor { void ReleaseAhwbStuff(); void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; + void MoveCpuOrSsboToAhwb() const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 90d89c40a..21bae9593 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -215,10 +215,15 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " "supported on targe system."; + bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; valid_ |= kValidAHardwareBuffer; - if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + if (transfer) { + MoveCpuOrSsboToAhwb(); + } else { + if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + } return {ahwb_, ssbo_written_, &fence_fd_, // The FD is created for SSBO -> AHWB synchronization. @@ -303,6 +308,39 @@ bool Tensor::AllocateAhwbMapToSsbo() const { return false; } +// Moves Cpu/Ssbo resource under the Ahwb backed memory. 
+void Tensor::MoveCpuOrSsboToAhwb() const { + void* dest = nullptr; + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_lock( + ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); + CHECK(error == 0) << "AHardwareBuffer_lock " << error; + } + if (valid_ & kValidOpenGlBuffer) { + gl_context_->Run([this, dest]() { + glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); + const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), + GL_MAP_READ_BIT); + std::memcpy(dest, src, bytes()); + glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); + glDeleteBuffers(1, &opengl_buffer_); + }); + opengl_buffer_ = GL_INVALID_INDEX; + gl_context_ = nullptr; + } else if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. + free(cpu_buffer_); + cpu_buffer_ = nullptr; + } else { + LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; + } + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_unlock(ahwb_, nullptr); + CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + } +} + // SSBO is created on top of AHWB. A fence is inserted into the GPU queue before // the GPU task that is going to read from the SSBO. When the writing into AHWB // is finished then the GPU reads from the SSBO. From ef1507ed5df48f00daaa2111518cc0a32faec3a6 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 14:43:59 -0800 Subject: [PATCH 116/346] Retire the visibility group "//mediapipe/framework:mediapipe_internal". PiperOrigin-RevId: 493713823 --- mediapipe/tasks/cc/core/BUILD | 5 ++++- mediapipe/tasks/cc/text/tokenizers/BUILD | 2 +- mediapipe/tasks/testdata/text/BUILD | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f8004d257..d440271df 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -309,7 +309,10 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = [ + "//mediapipe/calculators:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 7f1ea2848..92fac8eaa 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/framework:mediapipe_internal"]) +package(default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 081e63c2c..a0131c056 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -18,7 +18,10 @@ load( ) package( - default_visibility = ["//mediapipe/framework:mediapipe_internal"], + default_visibility = [ + "//mediapipe/calculators/tensor:__subpackages__", + "//mediapipe/tasks:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) From 91664eb254bb44adb03b1e4823e0a5250d1f3837 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 14:52:58 -0800 Subject: [PATCH 117/346] Object Detector deduplication PiperOrigin-RevId: 493716159 --- mediapipe/calculators/util/BUILD | 17 +++ .../util/detections_deduplicate_calculator.cc | 114 ++++++++++++++++++ .../tasks/cc/vision/object_detector/BUILD | 1 + .../object_detector/object_detector_graph.cc | 7 +- 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 mediapipe/calculators/util/detections_deduplicate_calculator.cc diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 43eadd53b..1529ead8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -456,6 +456,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detections_deduplicate_calculator", + srcs = [ + "detections_deduplicate_calculator.cc", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], diff --git a/mediapipe/calculators/util/detections_deduplicate_calculator.cc b/mediapipe/calculators/util/detections_deduplicate_calculator.cc new file mode 100644 index 000000000..2dfa09028 --- /dev/null +++ b/mediapipe/calculators/util/detections_deduplicate_calculator.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <cstddef>
+#include <functional>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+namespace {
+
+struct BoundingBoxHash {
+  size_t operator()(const LocationData::BoundingBox& bbox) const {
+    return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^
+           std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height());
+  }
+};
+
+struct BoundingBoxEq {
+  bool operator()(const LocationData::BoundingBox& lhs,
+                  const LocationData::BoundingBox& rhs) const {
+    return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() &&
+           lhs.width() == rhs.width() && lhs.height() == rhs.height();
+  }
+};
+
+}  // namespace
+
+// This Calculator deduplicates the bounding boxes with exactly the same
+// coordinates, and folds the labels into a single Detection proto. Note that
+// non-maximum-suppression removes the overlapping bounding boxes within a
+// class, while the deduplication operation merges bounding boxes from
+// different classes.
+
+// Example config:
+// node {
+//   calculator: "DetectionsDeduplicateCalculator"
+//   input_stream: "detections"
+//   output_stream: "deduplicated_detections"
+// }
+class DetectionsDeduplicateCalculator : public Node {
+ public:
+  static constexpr Input<std::vector<Detection>> kIn{""};
+  static constexpr Output<std::vector<Detection>> kOut{""};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  absl::Status Open(mediapipe::CalculatorContext* cc) {
+    cc->SetOffset(::mediapipe::TimestampDiff(0));
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(mediapipe::CalculatorContext* cc) {
+    const std::vector<Detection>& raw_detections = kIn(cc).Get();
+    absl::flat_hash_map<LocationData::BoundingBox, Detection*,
+                        BoundingBoxHash, BoundingBoxEq>
+        bbox_to_detections;
+    std::vector<Detection> deduplicated_detections;
+    // Reserve space for the worst case (no duplicates) so that the push_back
+    // calls below never reallocate and the Detection* values stored in
+    // bbox_to_detections stay valid for the whole loop.
+    deduplicated_detections.reserve(raw_detections.size());
+    for (const auto& detection : raw_detections) {
+      if (!detection.has_location_data() ||
+          !detection.location_data().has_bounding_box()) {
+        return absl::InvalidArgumentError(
+            "The location data of Detections must be BoundingBox.");
+      }
+      if (bbox_to_detections.contains(
+              detection.location_data().bounding_box())) {
+        // The bbox location already exists. Merge the detection labels into
+        // the existing detection proto.
+        Detection& deduplicated_detection =
+            *bbox_to_detections[detection.location_data().bounding_box()];
+        deduplicated_detection.mutable_score()->MergeFrom(detection.score());
+        deduplicated_detection.mutable_label()->MergeFrom(detection.label());
+        deduplicated_detection.mutable_label_id()->MergeFrom(
+            detection.label_id());
+        deduplicated_detection.mutable_display_name()->MergeFrom(
+            detection.display_name());
+      } else {
+        // The bbox location appears for the first time. Add the detection to
+        // the output detection vector.
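+        // Keeping a raw pointer to the appended element is safe here: the
+        // vector was reserved to raw_detections.size() above, so push_back
+        // cannot reallocate within this loop.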
+        deduplicated_detections.push_back(detection);
+        bbox_to_detections[detection.location_data().bounding_box()] =
+            &deduplicated_detections.back();
+      }
+    }
+    kOut(cc).Send(std::move(deduplicated_detections));
+    return absl::OkStatus();
+  }
+};
+
+MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD
index c2dd9995d..224eca520 100644
--- a/mediapipe/tasks/cc/vision/object_detector/BUILD
+++ b/mediapipe/tasks/cc/vision/object_detector/BUILD
@@ -63,6 +63,7 @@ cc_library(
         "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
         "//mediapipe/calculators/util:detection_projection_calculator",
         "//mediapipe/calculators/util:detection_transformation_calculator",
+        "//mediapipe/calculators/util:detections_deduplicate_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
index a1625c16c..fd95bb1ac 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
@@ -662,11 +662,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
     detection_transformation.Out(kPixelDetectionsTag) >>
         detection_label_id_to_text.In("");

+    // Deduplicate Detections with same bounding box coordinates.
+    auto& detections_deduplicate =
+        graph.AddNode("DetectionsDeduplicateCalculator");
+    detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
+
     // Outputs the labeled detections and the processed image as the subgraph
     // output streams.
     return {{
         /* detections= */
-        detection_label_id_to_text[Output<std::vector<Detection>>("")],
+        detections_deduplicate[Output<std::vector<Detection>>("")],
        /* image= */ preprocessing[Output<Image>(kImageTag)],
     }};
   }

From 5f97b29b3ba41d1a7765221ed242e0fab9a89751 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 7 Dec 2022 15:23:10 -0800
Subject: [PATCH 118/346] Update Bazel dependencies for Apple

PiperOrigin-RevId: 493723833
---
 WORKSPACE | 56 +++++++++++++++++++++++++------------------------------
 1 file changed, 25 insertions(+), 31 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index d43394883..bf5e4236b 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -320,12 +320,30 @@ http_archive(
     ],
 )

-# iOS basic build deps.
+# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee
+# that the target @zlib//:mini_zlib is available
+http_archive(
+    name = "zlib",
+    build_file = "//third_party:zlib.BUILD",
+    sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+    strip_prefix = "zlib-1.2.11",
+    urls = [
+        "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
+        "http://zlib.net/fossils/zlib-1.2.11.tar.gz",  # 2017-01-15
+    ],
+    patches = [
+        "@//third_party:zlib.diff",
+    ],
+    patch_args = [
+        "-p1",
+    ],
+)

+# iOS basic build deps.
http_archive( name = "build_bazel_rules_apple", - sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2", - url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz", + sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", + url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", patches = [ # Bypass checking ios unit test runner when building MP ios applications. "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" @@ -339,29 +357,24 @@ load( "@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies", ) - apple_rules_dependencies() load( "@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies", ) - swift_rules_dependencies() -http_archive( - name = "build_bazel_apple_support", - sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35", - urls = [ - "https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz" - ], +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", ) +swift_rules_extra_dependencies() load( "@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies", ) - apple_support_dependencies() # More iOS deps. @@ -442,25 +455,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow to guarantee that the target -# @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - # TensorFlow repo should always go after the other external dependencies. # TF on 2022-08-10. _TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b" From a59f0a99243a77c5f1cef684c5cd542c320c59f8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 15:49:08 -0800 Subject: [PATCH 119/346] Make java/C++/python tasks API public visible. 
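
Opening up these targets lets code outside the MediaPipe tree depend on the
Tasks APIs directly. A rough C++ sketch of the usage this enables follows
(hedged: the RunDetection wrapper, model path, and max_results value are
illustrative placeholders, and the option/result types are assumed to match
this release of the ObjectDetector API):

  #include <memory>
  #include <utility>

  #include "absl/status/status.h"
  #include "mediapipe/framework/formats/image.h"
  #include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

  // Assumes a BUILD dep on //mediapipe/tasks/cc/vision/object_detector,
  // which this change makes publicly visible.
  absl::Status RunDetection(mediapipe::Image image) {
    auto options =
        std::make_unique<mediapipe::tasks::vision::ObjectDetectorOptions>();
    // Placeholder path; point this at a real TFLite detection model.
    options->base_options.model_asset_path = "/path/to/model.tflite";
    options->max_results = 5;
    auto detector_or =
        mediapipe::tasks::vision::ObjectDetector::Create(std::move(options));
    if (!detector_or.ok()) return detector_or.status();
    auto detector = std::move(*detector_or);
    auto detections_or = detector->Detect(image);
    if (!detections_or.ok()) return detections_or.status();
    // ... consume *detections_or (already deduplicated as of PATCH 117) ...
    return detector->Close();
  }
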
PiperOrigin-RevId: 493730506 --- mediapipe/tasks/cc/audio/audio_classifier/BUILD | 4 +--- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 4 +--- mediapipe/tasks/cc/vision/object_detector/BUILD | 4 +--- mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD | 2 +- .../com/google/mediapipe/tasks/components/containers/BUILD | 2 +- .../com/google/mediapipe/tasks/components/processors/BUILD | 2 +- .../java/com/google/mediapipe/tasks/components/utils/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD | 2 +- mediapipe/tasks/python/audio/BUILD | 2 +- mediapipe/tasks/python/audio/core/BUILD | 2 +- mediapipe/tasks/python/components/containers/BUILD | 2 +- mediapipe/tasks/python/components/processors/BUILD | 2 +- mediapipe/tasks/python/components/utils/BUILD | 2 +- mediapipe/tasks/python/core/BUILD | 2 +- mediapipe/tasks/python/text/BUILD | 2 +- mediapipe/tasks/python/text/core/BUILD | 2 +- mediapipe/tasks/python/vision/BUILD | 2 +- mediapipe/tasks/python/vision/core/BUILD | 2 +- 20 files changed, 20 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index f61472413..c575caabe 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_classifier", srcs = ["audio_classifier.cc"], hdrs = ["audio_classifier.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index 6a0f627b2..1dfdd6f1b 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_embedder", srcs = ["audio_embedder.cc"], hdrs = ["audio_embedder.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 224eca520..77373303a 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -22,9 +22,7 @@ cc_library( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":object_detector_graph", "//mediapipe/calculators/core:concatenate_vector_calculator", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2d29ccf23..e5d472e8a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index ad17d5552..4d302b950 100644 --- 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index 1f99f1612..b4d453935 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD index b2d27bfa7..6c724106f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 01b1f653a..31f885267 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 5b10e9aab..31cd2c89a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) # The native library of all MediaPipe text tasks. cc_binary( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 6161fe032..f469aed0c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index ce7c5ce08..6dda7a53c 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 3cb9cb8e8..5b4203d7b 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 9d275e167..7108617ff 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index f87a579b0..695f6df91 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index 31114f326..1a18531c6 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index fc0018ab1..447189d6f 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index e2a51cdbd..9d5d23261 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD index 072a0c7d8..e76bd4b6d 100644 --- a/mediapipe/tasks/python/text/core/BUILD +++ b/mediapipe/tasks/python/text/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 241ca4341..5f4aa38ff 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index e2b2b3dec..18df690a0 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) From a0efcb47f23666f84448d82fcede6dab9fdfbf55 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 16:37:08 -0800 Subject: [PATCH 120/346] internal change PiperOrigin-RevId: 493742399 --- .../tasks/cc/components/containers/BUILD | 13 + .../components/containers/detection_result.cc | 73 ++++++ .../components/containers/detection_result.h | 52 ++++ .../tasks/cc/components/containers/rect.cc | 34 +++ .../tasks/cc/components/containers/rect.h | 29 ++- .../cc/vision/core/base_vision_task_api.h | 4 +- .../cc/vision/core/image_processing_options.h | 3 +- ...hand_landmarks_deduplication_calculator.cc | 14 +- .../hand_landmarker/hand_landmarker_test.cc | 4 +- .../image_classifier/image_classifier_test.cc | 19 +- .../image_embedder/image_embedder_test.cc | 6 +- .../image_segmenter/image_segmenter_test.cc | 4 +- .../tasks/cc/vision/object_detector/BUILD | 1 + .../vision/object_detector/object_detector.cc | 15 +- .../vision/object_detector/object_detector.h | 12 +- .../object_detector/object_detector_test.cc | 227 ++++++++++-------- .../tasks/cc/vision/utils/landmarks_utils.cc | 8 +- .../tasks/cc/vision/utils/landmarks_utils.h | 10 +- 18 files changed, 377 insertions(+), 151 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.cc create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.h create mode 100644 mediapipe/tasks/cc/components/containers/rect.cc diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 35d3f4785..0750a1482 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -18,6 +18,7 @@ licenses(["notice"]) cc_library( name = "rect", + srcs = ["rect.cc"], hdrs = ["rect.h"], ) @@ -41,6 +42,18 @@ cc_library( ], ) +cc_library( + name = "detection_result", + srcs = ["detection_result.cc"], + hdrs = ["detection_result.h"], + deps = [ + ":category", + ":rect", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + ], +) + cc_library( name = "embedding_result", srcs = ["embedding_result.cc"], diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc new file mode 100644 index 000000000..43c8ca0f5 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/containers/detection_result.h"
+
+#include <stdint.h>
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe::tasks::components::containers {
+
+constexpr int kDefaultCategoryIndex = -1;
+
+Detection ConvertToDetection(const mediapipe::Detection& detection_proto) {
+  Detection detection;
+  for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
+    detection.categories.push_back(
+        {/* index= */ detection_proto.label_id_size() > idx
+             ? detection_proto.label_id(idx)
+             : kDefaultCategoryIndex,
+         /* score= */ detection_proto.score(idx),
+         /* category_name */ detection_proto.label_size() > idx
+             ? detection_proto.label(idx)
+             : "",
+         /* display_name */ detection_proto.display_name_size() > idx
+             ? detection_proto.display_name(idx)
+             : ""});
+  }
+  Rect bounding_box;
+  if (detection_proto.location_data().has_bounding_box()) {
+    mediapipe::LocationData::BoundingBox bounding_box_proto =
+        detection_proto.location_data().bounding_box();
+    bounding_box.left = bounding_box_proto.xmin();
+    bounding_box.top = bounding_box_proto.ymin();
+    bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
+    bounding_box.bottom =
+        bounding_box_proto.ymin() + bounding_box_proto.height();
+  }
+  detection.bounding_box = bounding_box;
+  return detection;
+}
+
+DetectionResult ConvertToDetectionResult(
+    std::vector<mediapipe::Detection> detections_proto) {
+  DetectionResult detection_result;
+  detection_result.detections.reserve(detections_proto.size());
+  for (const auto& detection_proto : detections_proto) {
+    detection_result.detections.push_back(ConvertToDetection(detection_proto));
+  }
+  return detection_result;
+}
+}  // namespace mediapipe::tasks::components::containers
diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h
new file mode 100644
index 000000000..546f324d6
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/detection_result.h
@@ -0,0 +1,52 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe::tasks::components::containers {
+
+// Detection for a single bounding box.
+struct Detection {
+  // A vector of detected categories.
+  std::vector<Category> categories;
+  // The bounding box location.
+  Rect bounding_box;
+};
+
+// Detection results of a model.
+struct DetectionResult {
+  // A vector of Detections.
+  std::vector<Detection> detections;
+};
+
+// Utility function to convert from Detection proto to Detection struct.
+Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
+
+// Utility function to convert from list of Detection proto to DetectionResult
+// struct.
+DetectionResult ConvertToDetectionResult(
+    std::vector<mediapipe::Detection> detections_proto);
+
+}  // namespace mediapipe::tasks::components::containers
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc
new file mode 100644
index 000000000..4a94832a6
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/rect.cc
@@ -0,0 +1,34 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe::tasks::components::containers {
+
+RectF ToRectF(const Rect& rect, int image_height, int image_width) {
+  return RectF{static_cast<float>(rect.left) / image_width,
+               static_cast<float>(rect.top) / image_height,
+               static_cast<float>(rect.right) / image_width,
+               static_cast<float>(rect.bottom) / image_height};
+}
+
+Rect ToRect(const RectF& rect, int image_height, int image_width) {
+  return Rect{static_cast<int>(rect.left * image_width),
+              static_cast<int>(rect.top * image_height),
+              static_cast<int>(rect.right * image_width),
+              static_cast<int>(rect.bottom * image_height)};
+}
+
+}  // namespace mediapipe::tasks::components::containers
diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h
index 3f5432cf2..551d91588 100644
--- a/mediapipe/tasks/cc/components/containers/rect.h
+++ b/mediapipe/tasks/cc/components/containers/rect.h
@@ -16,20 +16,47 @@ limitations under the License.
 #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
 #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
 
+#include <cmath>
+
 namespace mediapipe::tasks::components::containers {
 
+constexpr float kRectFTolerance = 1e-4;
+
 // Defines a rectangle, used e.g. as part of detection results or as input
 // region-of-interest.
 //
+struct Rect {
+  int left;
+  int top;
+  int right;
+  int bottom;
+};
+
+inline bool operator==(const Rect& lhs, const Rect& rhs) {
+  return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right &&
+         lhs.bottom == rhs.bottom;
+}
+
 // The coordinates are normalized wrt the image dimensions, i.e. generally in
 // [0,1] but they may exceed these bounds if describing a region overlapping the
 // image. The origin is on the top-left corner of the image.
-struct Rect {
+struct RectF {
   float left;
   float top;
   float right;
   float bottom;
 };
 
+inline bool operator==(const RectF& lhs, const RectF& rhs) {
+  return std::abs(lhs.left - rhs.left) < kRectFTolerance &&
+         std::abs(lhs.top - rhs.top) < kRectFTolerance &&
+         std::abs(lhs.right - rhs.right) < kRectFTolerance &&
+         std::abs(lhs.bottom - rhs.bottom) < kRectFTolerance;
+}
+
+RectF ToRectF(const Rect& rect, int image_height, int image_width);
+
+Rect ToRect(const RectF& rect, int image_height, int image_width);
+
 }  // namespace mediapipe::tasks::components::containers
 
 #endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
index c3c0a0261..a86b2cca8 100644
--- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
+++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
@@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
       if (roi.left >= roi.right || roi.top >= roi.bottom) {
         return CreateStatusWithPayload(
             absl::StatusCode::kInvalidArgument,
-            "Expected Rect with left < right and top < bottom.",
+            "Expected RectF with left < right and top < bottom.",
             MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
       }
       if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
         return CreateStatusWithPayload(
             absl::StatusCode::kInvalidArgument,
-            "Expected Rect values to be in [0,1].",
+            "Expected RectF values to be in [0,1].",
            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
       }
       normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h
index 7e764c1fe..1983272fc 100644
--- a/mediapipe/tasks/cc/vision/core/image_processing_options.h
+++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h
@@ -35,7 +35,8 @@ struct ImageProcessingOptions {
   // the full image is used.
   //
   // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom.
-  std::optional<Rect> region_of_interest = std::nullopt;
+  std::optional<RectF> region_of_interest =
+      std::nullopt;
 
   // The rotation to apply to the image (or cropped region-of-interest), in
   // degrees clockwise.
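
The conversion helpers introduced above are the intended bridge between the two
structs: detection outputs stay in the pixel-based `Rect`, while
region-of-interest inputs now use the normalized `RectF`. The following
standalone snippet is a minimal sketch of the round trip; it is illustrative
only and not part of the patch, and the image dimensions and box values are
made up:

    #include <cassert>

    #include "mediapipe/tasks/cc/components/containers/rect.h"

    int main() {
      using ::mediapipe::tasks::components::containers::Rect;
      using ::mediapipe::tasks::components::containers::RectF;
      using ::mediapipe::tasks::components::containers::ToRect;
      using ::mediapipe::tasks::components::containers::ToRectF;

      // A pixel-space box inside a 640x512 image (values chosen to be exactly
      // representable as floats so the round trip is lossless).
      Rect pixel_box{/*left=*/160, /*top=*/128, /*right=*/320, /*bottom=*/384};
      // ToRectF() divides x coordinates by the width and y by the height.
      RectF normalized =
          ToRectF(pixel_box, /*image_height=*/512, /*image_width=*/640);
      // ToRect() maps normalized coordinates back to (truncated) pixels.
      Rect round_trip =
          ToRect(normalized, /*image_height=*/512, /*image_width=*/640);
      assert(round_trip == pixel_box);  // Rect's operator== compares exactly.
      return 0;
    }
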
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
index 564184c64..266ce223f 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
@@ -44,7 +44,7 @@ namespace {
 using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::Source;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::utils::CalculateIOU;
 using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
 
@@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
   return distance;
 }
 
-Rect CalculateBound(const NormalizedLandmarkList& list) {
+RectF CalculateBound(const NormalizedLandmarkList& list) {
   constexpr float kMinInitialValue = std::numeric_limits<float>::max();
   constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
 
@@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
   }
 
   // Populate normalized non rotated face bounding box
-  return Rect{/*left=*/bounding_box_left,
-              /*top=*/bounding_box_top,
-              /*right=*/bounding_box_right,
-              /*bottom=*/bounding_box_bottom};
+  return RectF{/*left=*/bounding_box_left,
+               /*top=*/bounding_box_top,
+               /*right=*/bounding_box_right,
+               /*bottom=*/bounding_box_bottom};
 }
 
 // Uses IoU and distance of some corresponding hand landmarks to detect
@@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
     const int num = multi_landmarks.size();
     std::vector<float> baseline_distances;
     baseline_distances.reserve(num);
-    std::vector<Rect> bounds;
+    std::vector<RectF> bounds;
     bounds.reserve(num);
     for (const NormalizedLandmarkList& list : multi_landmarks) {
       ASSIGN_OR_RETURN(const float baseline_distance,
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
index fa49a4c1f..94d1b1c12 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
@@ -50,7 +50,7 @@ namespace {
 
 using ::file::Defaults;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::EqualsProto;
@@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
   options->running_mode = core::RunningMode::IMAGE;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                           HandLandmarker::Create(std::move(options)));
-  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = hand_landmarker->Detect(image, image_processing_options);
diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
index 1144e9032..7aa2a148c 100644
--- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
+++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
@@ -52,7 +52,7 @@ namespace {
 
 using ::mediapipe::file::JoinPath;
 using ::mediapipe::tasks::components::containers::Category;
 using ::mediapipe::tasks::components::containers::Classifications;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
   // Region-of-interest around the soccer ball.
-  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
@@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
   // Region-of-interest around the chair, with 90° anti-clockwise rotation.
-  Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
+  RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702,
+            /*bottom=*/0.3049};
   ImageProcessingOptions image_processing_options{roi,
                                                   /*rotation_degrees=*/-90};
 
@@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
                           ImageClassifier::Create(std::move(options)));
 
   // Invalid: left > right.
-  Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
+  RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = image_classifier->Classify(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("Expected Rect with left < right and top < bottom"));
+              HasSubstr("Expected RectF with left < right and top < bottom"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
@@ -573,7 +574,7 @@
   results = image_classifier->Classify(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("Expected Rect with left < right and top < bottom"));
+              HasSubstr("Expected RectF with left < right and top < bottom"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
@@ -586,7 +587,7 @@
   results = image_classifier->Classify(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("Expected Rect values to be in [0,1]"));
+              HasSubstr("Expected RectF values to be in [0,1]"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
@@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
                           ImageClassifier::Create(std::move(options)));
   // Crop around the soccer ball.
   // Region-of-interest around the soccer ball.
-  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   for (int i = 0; i < iterations; ++i) {
@@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
   // Crop around the soccer ball.
-  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   for (int i = 0; i < iterations; ++i) {
diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
index 6098a9a70..dd602bef5 100644
--- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
@@ -41,7 +41,7 @@ namespace image_embedder {
 namespace {
 
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
       Image crop, DecodeImageFromFile(
                       JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
   // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
-  Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
+  RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   // Extract both embeddings.
@@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
                           DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                        "burger_rotated.jpg")));
   // Region-of-interest corresponding to burger_crop.jpg.
-  Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
+  RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
   ImageProcessingOptions image_processing_options{roi,
                                                   /*rotation_degrees=*/-90};
 
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
index d5ea088a1..f9618c1b1 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
@@ -47,7 +47,7 @@ namespace {
 
 using ::mediapipe::Image;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = segmenter->Segment(image, image_processing_options);
diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD
index 77373303a..5269796ae 100644
--- a/mediapipe/tasks/cc/vision/object_detector/BUILD
+++ b/mediapipe/tasks/cc/vision/object_detector/BUILD
@@ -33,6 +33,7 @@ cc_library(
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
+        "//mediapipe/tasks/cc/components/containers:detection_result",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
index dd19237ff..e0222dd70 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; @@ -129,7 +131,8 @@ absl::StatusOr> ObjectDetector::Create( Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(detections_packet.Get>(), + result_callback(ConvertToDetectionResult( + detections_packet.Get>()), image_packet.Get(), detections_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -144,7 +147,7 @@ absl::StatusOr> ObjectDetector::Create( std::move(packets_callback)); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr ObjectDetector::Detect( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -161,10 +164,11 @@ absl::StatusOr> ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } -absl::StatusOr> ObjectDetector::DetectForVideo( +absl::StatusOr ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -185,7 +189,8 @@ absl::StatusOr> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } absl::Status ObjectDetector::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 44ce68ed9..249a2ebf5 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -36,6 +37,10 @@ namespace mediapipe { namespace tasks { namespace vision { +// Alias the shared DetectionResult struct as result typo. +using ObjectDetectorResult = + ::mediapipe::tasks::components::containers::DetectionResult; + // The options for configuring a mediapipe object detector task. struct ObjectDetectorOptions { // Base options for configuring MediaPipe Tasks, such as specifying the TfLite @@ -79,8 +84,7 @@ struct ObjectDetectorOptions { // The user-defined result callback for processing live stream data. 
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
-  std::function<void(absl::StatusOr<std::vector<Detection>>,
-                     const Image&, int64)>
+  std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)>
       result_callback = nullptr;
 };
 
@@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // underlying image data.
   // TODO: Describes the output bounding boxes for gpu input
   // images after enabling the gpu support in MediaPipe Tasks.
-  absl::StatusOr<std::vector<Detection>> Detect(
+  absl::StatusOr<ObjectDetectorResult> Detect(
       mediapipe::Image image,
       std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
@@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // unrotated input frame of reference coordinates system, i.e. in `[0,
   // image_width) x [0, image_height)`, which are the dimensions of the
   // underlying image data.
-  absl::StatusOr<std::vector<Detection>> DetectForVideo(
+  absl::StatusOr<ObjectDetectorResult> DetectForVideo(
       mediapipe::Image image, int64 timestamp_ms,
       std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
index 1747685dd..798e3f238 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/containers/detection_result.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -65,10 +66,14 @@ namespace vision {
 namespace {
 
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
+using ::mediapipe::tasks::components::containers::Detection;
+using ::mediapipe::tasks::components::containers::DetectionResult;
+using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;
+using DetectionProto = mediapipe::Detection;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kMobileSsdWithMetadata[] =
@@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] =
 // Checks that the two provided `Detection` proto vectors are equal, with a
 // tolerancy on floating-point scores to account for numerical instabilities.
 // If the proto definition changes, please also change this function.
-void ExpectApproximatelyEqual(const std::vector<Detection>& actual,
-                              const std::vector<Detection>& expected) {
+void ExpectApproximatelyEqual(const ObjectDetectorResult& actual,
+                              const ObjectDetectorResult& expected) {
   const float kPrecision = 1e-6;
-  EXPECT_EQ(actual.size(), expected.size());
-  for (int i = 0; i < actual.size(); ++i) {
-    const Detection& a = actual[i];
-    const Detection& b = expected[i];
-    EXPECT_THAT(a.location_data().bounding_box(),
-                EqualsProto(b.location_data().bounding_box()));
-    EXPECT_EQ(a.label_size(), 1);
-    EXPECT_EQ(b.label_size(), 1);
-    EXPECT_EQ(a.label(0), b.label(0));
-    EXPECT_EQ(a.score_size(), 1);
-    EXPECT_EQ(b.score_size(), 1);
-    EXPECT_NEAR(a.score(0), b.score(0), kPrecision);
+  EXPECT_EQ(actual.detections.size(), expected.detections.size());
+  for (int i = 0; i < actual.detections.size(); ++i) {
+    const Detection& a = actual.detections[i];
+    const Detection& b = expected.detections[i];
+    EXPECT_EQ(a.bounding_box, b.bounding_box);
+    EXPECT_EQ(a.categories.size(), 1);
+    EXPECT_EQ(b.categories.size(), 1);
+    EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name);
+    EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision);
   }
 }
 
-std::vector<Detection> GenerateMobileSsdNoImageResizingFullExpectedResults() {
-  return {ParseTextProtoOrDie<Detection>(R"pb(
+std::vector<DetectionProto>
+GenerateMobileSsdNoImageResizingFullExpectedResults() {
+  return {ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "cat"
            score: 0.6328125
             location_data {
               format: BOUNDING_BOX
              bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
             })pb"),
-          ParseTextProtoOrDie<Detection>(R"pb(
+          ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "cat"
             score: 0.59765625
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
             })pb"),
-          ParseTextProtoOrDie<Detection>(R"pb(
+          ParseTextProtoOrDie<DetectionProto>(R"pb(
            label: "cat"
             score: 0.5
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
             })pb"),
-          ParseTextProtoOrDie<Detection>(R"pb(
+          ParseTextProtoOrDie<DetectionProto>(R"pb(
            label: "dog"
             score: 0.48828125
             location_data {
@@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
         JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
     options->running_mode = running_mode;
     options->result_callback =
-        [](absl::StatusOr<std::vector<Detection>> detections,
-           const Image& image, int64 timestamp_ms) {};
+        [](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
+           int64 timestamp_ms) {};
     absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
         ObjectDetector::Create(std::move(options));
     EXPECT_EQ(object_detector.status().code(),
@@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
   ExpectApproximatelyEqual(
-      results, {ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.69921875
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
-                  })pb"),
-                ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.64453125
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
-                  })pb"),
-                ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.51171875
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
-                  })pb"),
-                ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.48828125
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
-                  })pb")});
+      results,
+      ConvertToDetectionResult(
+          {ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.69921875
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.64453125
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.51171875
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.48828125
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
+             })pb")}));
 }
 
 TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
@@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
   ExpectApproximatelyEqual(
-      results, {ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.7578125
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
-                  })pb"),
-                ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.72265625
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
-                  })pb"),
-                ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.6289063
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
-                  })pb"),
-                ParseTextProtoOrDie<Detection>(R"pb(
-                  label: "cat"
-                  score: 0.5859375
-                  location_data {
-                    format: BOUNDING_BOX
-                    bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
-                  })pb")});
+      results,
+      ConvertToDetectionResult(
+          {ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.7578125
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.72265625
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.6289063
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.5859375
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
+             })pb")}));
 }
 
 TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
@@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
   ExpectApproximatelyEqual(
-      results, GenerateMobileSsdNoImageResizingFullExpectedResults());
+      results, ConvertToDetectionResult(
+                   GenerateMobileSsdNoImageResizingFullExpectedResults()));
 }
 
 TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
@@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
   ExpectApproximatelyEqual(
-      results, {ParseTextProtoOrDie<Detection>(R"pb(
+      results,
+      ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
         label: "cat"
         score: 0.6531269142
        location_data {
          format: BOUNDING_BOX
          bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
-      })pb")});
+      })pb")}));
 }
 
 TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
@@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
                           ObjectDetector::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
-  std::vector<Detection> full_expected_results =
+  std::vector<DetectionProto> full_expected_results =
       GenerateMobileSsdNoImageResizingFullExpectedResults();
-  ExpectApproximatelyEqual(results,
-                           {full_expected_results[0], full_expected_results[1],
-                            full_expected_results[2]});
+
+  ExpectApproximatelyEqual(
+      results, ConvertToDetectionResult({full_expected_results[0],
+                                         full_expected_results[1],
+                                         full_expected_results[2]}));
 }
 
 TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
@@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
                           ObjectDetector::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
-  std::vector<Detection> full_expected_results =
+  std::vector<DetectionProto> full_expected_results =
       GenerateMobileSsdNoImageResizingFullExpectedResults();
   ExpectApproximatelyEqual(
-      results, {full_expected_results[0], full_expected_results[1]});
+      results, ConvertToDetectionResult(
+                   {full_expected_results[0], full_expected_results[1]}));
 }
 
 TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
@@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
                           ObjectDetector::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
-  std::vector<Detection> full_expected_results =
+  std::vector<DetectionProto> full_expected_results =
       GenerateMobileSsdNoImageResizingFullExpectedResults();
-  ExpectApproximatelyEqual(results, {full_expected_results[3]});
+  ExpectApproximatelyEqual(
+      results, ConvertToDetectionResult({full_expected_results[3]}));
 }
 
 TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
@@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
                           ObjectDetector::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
   MP_ASSERT_OK(object_detector->Close());
-  std::vector<Detection> full_expected_results =
+  std::vector<DetectionProto> full_expected_results =
      GenerateMobileSsdNoImageResizingFullExpectedResults();
-  ExpectApproximatelyEqual(results, {full_expected_results[3]});
+  ExpectApproximatelyEqual(
+      results, ConvertToDetectionResult({full_expected_results[3]}));
 }
 
 TEST_F(ImageModeTest, SucceedsWithRotation) {
@@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
       auto results, object_detector->Detect(image, image_processing_options));
   MP_ASSERT_OK(object_detector->Close());
   ExpectApproximatelyEqual(
-      results, {ParseTextProtoOrDie<Detection>(R"pb(
+      results,
+      ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
        label: "cat"
        score: 0.7109375
        location_data {
          format: BOUNDING_BOX
          bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
-      })pb")});
+      })pb")}));
 }
 
 TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
@@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
-  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
   ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
 
   auto results = object_detector->Detect(image, image_processing_options);
@@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) {
   for (int i = 0; i < iterations; ++i) {
     MP_ASSERT_OK_AND_ASSIGN(auto results,
                             object_detector->DetectForVideo(image, i));
-    std::vector<Detection> full_expected_results =
+    std::vector<DetectionProto> full_expected_results =
        GenerateMobileSsdNoImageResizingFullExpectedResults();
     ExpectApproximatelyEqual(
-        results, {full_expected_results[0], full_expected_results[1]});
+        results, ConvertToDetectionResult(
+                     {full_expected_results[0], full_expected_results[1]}));
   }
   MP_ASSERT_OK(object_detector->Close());
 }
@@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback =
-      [](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
-         int64 timestamp_ms) {};
+  options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
+                                const Image& image, int64 timestamp_ms) {};
 
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
@@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
   options->running_mode = core::RunningMode::LIVE_STREAM;
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
-  options->result_callback =
-      [](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
-         int64 timestamp_ms) {};
+  options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
+                                const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
   MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
@@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   auto options = std::make_unique<ObjectDetectorOptions>();
   options->max_results = 2;
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  std::vector<std::vector<Detection>> detection_results;
+  std::vector<ObjectDetectorResult> detection_results;
   std::vector<std::pair<int, int>> image_sizes;
   std::vector<int64> timestamps;
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   options->result_callback = [&detection_results, &image_sizes, &timestamps](
-                                 absl::StatusOr<std::vector<Detection>> detections, const Image& image,
+                                 absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
                                  int64 timestamp_ms) {
     MP_ASSERT_OK(detections.status());
     detection_results.push_back(std::move(detections).value());
@@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   // number of iterations.
   ASSERT_LE(detection_results.size(), iterations);
   ASSERT_GT(detection_results.size(), 0);
-  std::vector<Detection> full_expected_results =
+  std::vector<DetectionProto> full_expected_results =
      GenerateMobileSsdNoImageResizingFullExpectedResults();
   for (const auto& detection_result : detection_results) {
     ExpectApproximatelyEqual(
-        detection_result, {full_expected_results[0], full_expected_results[1]});
+        detection_result, ConvertToDetectionResult({full_expected_results[0],
+                                                    full_expected_results[1]}));
   }
   for (const auto& image_size : image_sizes) {
     EXPECT_EQ(image_size.first, image.width());
diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
index 2ce9e2454..fe4e63824 100644
--- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
+++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
@@ -22,13 +22,13 @@ limitations under the License.
namespace mediapipe::tasks::vision::utils { -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; -float CalculateArea(const Rect& rect) { +float CalculateArea(const RectF& rect) { return (rect.right - rect.left) * (rect.bottom - rect.top); } -float CalculateIntersectionArea(const Rect& a, const Rect& b) { +float CalculateIntersectionArea(const RectF& a, const RectF& b) { const float intersection_left = std::max(a.left, b.left); const float intersection_top = std::max(a.top, b.top); const float intersection_right = std::min(a.right, b.right); @@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) { std::max(intersection_right - intersection_left, 0.0); } -float CalculateIOU(const Rect& a, const Rect& b) { +float CalculateIOU(const RectF& a, const RectF& b) { const float area_a = CalculateArea(a); const float area_b = CalculateArea(b); if (area_a <= 0 || area_b <= 0) return 0.0; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index 73114d2ef..4d1fac62f 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -27,15 +27,15 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { // Calculates intersection over union for two bounds. -float CalculateIOU(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIOU(const components::containers::RectF& a, + const components::containers::RectF& b); // Calculates area for face bound -float CalculateArea(const components::containers::Rect& rect); +float CalculateArea(const components::containers::RectF& rect); // Calucates intersection area of two face bounds -float CalculateIntersectionArea(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIntersectionArea(const components::containers::RectF& a, + const components::containers::RectF& b); } // namespace mediapipe::tasks::vision::utils #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ From 700c7b4b2258d3a01bf8424146a4cf94e8ca7282 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 18:54:34 -0800 Subject: [PATCH 121/346] Internal refactoring for TextEmbedder. 
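
This change factors the preprocessor choice out of the text preprocessing
graph: the new TextModelType proto enumerates the supported model families
(BERT, regex-tokenized, plain string input), and the new
text/utils:text_model_utils library infers the family from a model's metadata
and input tensors. A rough sketch of the intended call pattern follows; it is
illustrative only and assumes just the GetModelType() helper added in this
change:

    // Illustrative: decide whether a loaded text model needs a tokenizer.
    #include "absl/status/statusor.h"
    #include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
    #include "mediapipe/tasks/cc/core/model_resources.h"
    #include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

    using ::mediapipe::tasks::components::processors::proto::TextModelType;

    absl::StatusOr<bool> NeedsTokenizer(
        const mediapipe::tasks::core::ModelResources& model_resources) {
      // GetModelType() returns BERT_MODEL, REGEX_MODEL or STRING_MODEL, or an
      // error status for unsupported models.
      auto model_type =
          mediapipe::tasks::text::utils::GetModelType(model_resources);
      if (!model_type.ok()) return model_type.status();
      // Only tokenizer-based families take a max_seq_len.
      return *model_type == TextModelType::BERT_MODEL ||
             *model_type == TextModelType::REGEX_MODEL;
    }
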
PiperOrigin-RevId: 493766612 --- .../tasks/cc/components/processors/BUILD | 3 + .../cc/components/processors/proto/BUILD | 6 + .../processors/proto/text_model_type.proto | 31 +++++ .../text_preprocessing_graph_options.proto | 15 +-- .../processors/text_preprocessing_graph.cc | 126 +++++------------- mediapipe/tasks/cc/text/utils/BUILD | 40 ++++++ .../tasks/cc/text/utils/text_model_utils.cc | 119 +++++++++++++++++ .../tasks/cc/text/utils/text_model_utils.h | 33 +++++ .../cc/text/utils/text_model_utils_test.cc | 108 +++++++++++++++ 9 files changed, 375 insertions(+), 106 deletions(-) create mode 100644 mediapipe/tasks/cc/components/processors/proto/text_model_type.proto create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils.cc create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils.h create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils_test.cc diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 185bf231b..cec44a9e3 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -150,9 +150,12 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/utils:text_model_utils", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index f48c4bad8..816ba47e3 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -60,10 +60,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "text_model_type_proto", + srcs = ["text_model_type.proto"], +) + mediapipe_proto_library( name = "text_preprocessing_graph_options_proto", srcs = ["text_preprocessing_graph_options.proto"], deps = [ + ":text_model_type_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto new file mode 100644 index 000000000..7ffc0db07 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +message TextModelType { + // TFLite text models supported by MediaPipe tasks. + enum ModelType { + UNSPECIFIED_MODEL = 0; + // A BERT-based model. + BERT_MODEL = 1; + // A model expecting input passed through a regex-based tokenizer. + REGEX_MODEL = 2; + // A model taking a string tensor input. + STRING_MODEL = 3; + } +} diff --git a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index a67cfd8a9..b610f7757 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -18,25 +18,16 @@ syntax = "proto2"; package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto"; message TextPreprocessingGraphOptions { extend mediapipe.CalculatorOptions { optional TextPreprocessingGraphOptions ext = 476978751; } - // The type of text preprocessor required for the TFLite model. - enum PreprocessorType { - UNSPECIFIED_PREPROCESSOR = 0; - // Used for the BertPreprocessorCalculator. - BERT_PREPROCESSOR = 1; - // Used for the RegexPreprocessorCalculator. - REGEX_PREPROCESSOR = 2; - // Used for the TextToTensorCalculator. - STRING_PREPROCESSOR = 3; - } - optional PreprocessorType preprocessor_type = 1; + optional TextModelType.ModelType model_type = 1; // The maximum input sequence length for the TFLite model. Used with - // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + // BERT_MODEL and REGEX_MODEL. optional int32 max_seq_len = 2; } diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index de16375bd..f6c15c441 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -25,15 +25,14 @@ limitations under the License. 
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" -namespace mediapipe { -namespace tasks { -namespace components { -namespace processors { - +namespace mediapipe::tasks::components::processors { namespace { using ::mediapipe::api2::Input; @@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::processors::proto::TextModelType; using ::mediapipe::tasks::components::processors::proto:: TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -constexpr int kNumInputTensorsForBert = 3; -constexpr int kNumInputTensorsForRegex = 1; - -// Gets the name of the MediaPipe calculator associated with -// `preprocessor_type`. -absl::StatusOr GetCalculatorNameFromPreprocessorType( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: +// Gets the name of the MediaPipe preprocessor calculator associated with +// `model_type`. +absl::StatusOr GetCalculatorNameFromModelType( + TextModelType::ModelType model_type) { + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + absl::StatusCode::kInvalidArgument, "Unspecified model type", MediaPipeTasksStatus::kInvalidArgumentError); - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextModelType::BERT_MODEL: return "BertPreprocessorCalculator"; - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + case TextModelType::REGEX_MODEL: return "RegexPreprocessorCalculator"; - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + case TextModelType::STRING_MODEL: return "TextToTensorCalculator"; } } -// Determines the PreprocessorType for the model based on its metadata as well -// as its input tensors' type and count. Returns an error if there is no -// compatible preprocessor. 
-absl::StatusOr -GetPreprocessorType(const ModelResources& model_resources) { - const tflite::SubGraph& model_graph = - *(*model_resources.GetTfLiteModel()->subgraphs())[0]; - bool all_int32_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; - }); - bool all_string_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; - }); - if (!all_int32_tensors && !all_string_tensors) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "All input tensors should have type int32 or all should have type " - "string", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); - } - if (all_string_tensors) { - return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; - } - - // Otherwise, all tensors should have type int32 - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); - if (metadata_extractor->GetModelMetadata() == nullptr || - metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Text models with int32 input tensors require TFLite Model " - "Metadata but none was found", - MediaPipeTasksStatus::kMetadataNotFoundError); - } - - if (model_graph.inputs()->size() == kNumInputTensorsForBert) { - return TextPreprocessingGraphOptions::BERT_PREPROCESSOR; - } - - if (model_graph.inputs()->size() == kNumInputTensorsForRegex) { - return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR; - } - - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::Substitute("Models with int32 input tensors should take exactly $0 " - "or $1 input tensors, but found $2", - kNumInputTensorsForBert, kNumInputTensorsForRegex, - model_graph.inputs()->size()), - MediaPipeTasksStatus::kInvalidNumInputTensorsError); -} - // Returns the maximum input sequence length accepted by the TFLite // model that owns `model graph` or returns an error if the model's input // tensors' shape is invalid for text preprocessing. 
This util assumes that the
@@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph(
         MediaPipeTasksStatus::kInvalidArgumentError);
   }

-  ASSIGN_OR_RETURN(
-      TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
-      GetPreprocessorType(model_resources));
-  options.set_preprocessor_type(preprocessor_type);
-  switch (preprocessor_type) {
-    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+  ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
+                   GetModelType(model_resources));
+  options.set_model_type(model_type);
+  switch (model_type) {
+    case TextModelType::UNSPECIFIED_MODEL:
+    case TextModelType::STRING_MODEL: {
       break;
     }
-    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+    case TextModelType::BERT_MODEL:
+    case TextModelType::REGEX_MODEL: {
       ASSIGN_OR_RETURN(
           int max_seq_len,
           GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
@@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
   absl::StatusOr<std::vector<Source<Tensor>>> BuildTextPreprocessing(
       const TextPreprocessingGraphOptions& options, Source<std::string> text_in,
       SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) {
-    ASSIGN_OR_RETURN(
-        std::string preprocessor_name,
-        GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
+    ASSIGN_OR_RETURN(std::string preprocessor_name,
+                     GetCalculatorNameFromModelType(options.model_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
-    switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+    switch (options.model_type()) {
+      case TextModelType::UNSPECIFIED_MODEL:
+      case TextModelType::STRING_MODEL: {
         break;
       }
-      case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
+      case TextModelType::BERT_MODEL: {
         text_preprocessor.GetOptions<mediapipe::BertPreprocessorCalculatorOptions>()
             .set_bert_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
             text_preprocessor.SideIn(kMetadataExtractorTag);
         break;
       }
-      case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      case TextModelType::REGEX_MODEL: {
         text_preprocessor.GetOptions<mediapipe::RegexPreprocessorCalculatorOptions>()
             .set_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
@@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
 REGISTER_MEDIAPIPE_GRAPH(
     ::mediapipe::tasks::components::processors::TextPreprocessingGraph);

-}  // namespace processors
-}  // namespace components
-}  // namespace tasks
-}  // namespace mediapipe
+}  // namespace mediapipe::tasks::components::processors
diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD
index 710e8a984..092a7d450 100644
--- a/mediapipe/tasks/cc/text/utils/BUILD
+++ b/mediapipe/tasks/cc/text/utils/BUILD
@@ -43,3 +43,43 @@ cc_test(
         "@com_google_absl//absl/container:node_hash_map",
     ],
 )
+
+cc_library(
+    name = "text_model_utils",
+    srcs = ["text_model_utils.cc"],
+    hdrs = ["text_model_utils.h"],
+    deps = [
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+)
+
+cc_test(
+    name = "text_model_utils_test",
+    srcs = ["text_model_utils_test.cc"],
+    data =
[ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_model_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc new file mode 100644 index 000000000..9d0005ec1 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -0,0 +1,119 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::text::utils { +namespace { + +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; +constexpr int kNumInputTensorsForStringPreprocessor = 1; + +// Determines the ModelType for a model with int32 input tensors based +// on the number of input tensors. Returns an error if there is missing metadata +// or an invalid number of input tensors. 
+absl::StatusOr<TextModelType::ModelType> GetIntTensorModelType(
+    const ModelResources& model_resources, int num_input_tensors) {
+  const ModelMetadataExtractor* metadata_extractor =
+      model_resources.GetMetadataExtractor();
+  if (metadata_extractor->GetModelMetadata() == nullptr ||
+      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Text models with int32 input tensors require TFLite Model "
+        "Metadata but none was found",
+        MediaPipeTasksStatus::kMetadataNotFoundError);
+  }
+
+  if (num_input_tensors == kNumInputTensorsForBert) {
+    return TextModelType::BERT_MODEL;
+  }
+
+  if (num_input_tensors == kNumInputTensorsForRegex) {
+    return TextModelType::REGEX_MODEL;
+  }
+
+  return CreateStatusWithPayload(
+      absl::StatusCode::kInvalidArgument,
+      absl::Substitute("Models with int32 input tensors should take exactly $0 "
+                       "or $1 input tensors, but found $2",
+                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
+                       num_input_tensors),
+      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
+}
+
+// Determines the ModelType for a model with string input tensors based
+// on the number of input tensors. Returns an error if there is an invalid
+// number of input tensors.
+absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
+    const ModelResources& model_resources, int num_input_tensors) {
+  if (num_input_tensors == kNumInputTensorsForStringPreprocessor) {
+    return TextModelType::STRING_MODEL;
+  }
+
+  return CreateStatusWithPayload(
+      absl::StatusCode::kInvalidArgument,
+      absl::Substitute("Models with string input tensors should take exactly "
+                       "$0 tensors, but found $1",
+                       kNumInputTensorsForStringPreprocessor,
+                       num_input_tensors),
+      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
+}
+}  // namespace
+
+absl::StatusOr<TextModelType::ModelType> GetModelType(
+    const ModelResources& model_resources) {
+  const tflite::SubGraph& model_graph =
+      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
+  bool all_int32_tensors =
+      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
+        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
+      });
+  bool all_string_tensors =
+      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
+        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
+      });
+  if (!all_int32_tensors && !all_string_tensors) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "All input tensors should have type int32 or all should have type "
+        "string",
+        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
+  }
+  if (all_string_tensors) {
+    return GetStringTensorModelType(model_resources,
+                                    model_graph.inputs()->size());
+  }
+
+  // Otherwise, all tensors should have type int32
+  return GetIntTensorModelType(model_resources, model_graph.inputs()->size());
+}
+
+}  // namespace mediapipe::tasks::text::utils
diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h
new file mode 100644
index 000000000..da8783d33
--- /dev/null
+++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h
@@ -0,0 +1,33 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
+
+#include "absl/status/statusor.h"
+#include "mediapipe/tasks/cc/components/processors/proto:text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+
+namespace mediapipe::tasks::text::utils {
+
+// Determines the ModelType for the model based on its metadata as well
+// as its input tensors' type and count. Returns an error if there is no
+// compatible model type.
+absl::StatusOr<components::processors::proto::TextModelType::ModelType>
+GetModelType(const core::ModelResources& model_resources);
+
+}  // namespace mediapipe::tasks::text::utils
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc
new file mode 100644
index 000000000..c02f8eca5
--- /dev/null
+++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe::tasks::text::utils {
+
+namespace {
+
+using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
+using ::mediapipe::tasks::core::ModelResources;
+using ::mediapipe::tasks::core::proto::ExternalFile;
+
+constexpr absl::string_view kTestModelResourcesTag = "test_model_resources";
+
+constexpr absl::string_view kTestDataDirectory =
+    "/mediapipe/tasks/testdata/text/";
+// Classification model with BERT preprocessing.
+constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite";
+// Embedding model with BERT preprocessing.
+constexpr absl::string_view kMobileBert =
+    "mobilebert_embedding_with_metadata.tflite";
+// Classification model with regex preprocessing.
+constexpr absl::string_view kRegexClassifierPath = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +// Embedding model with regex preprocessing. +constexpr absl::string_view kRegexOneEmbeddingModel = + "regex_one_embedding_with_metadata.tflite"; +// Classification model that takes a string tensor and outputs a bool tensor. +constexpr absl::string_view kStringToBoolModelPath = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +absl::StatusOr GetModelTypeFromFile( + absl::string_view file_name) { + auto model_file = std::make_unique(); + model_file->set_file_name(GetFullPath(file_name)); + ASSIGN_OR_RETURN(auto model_resources, + ModelResources::Create(std::string(kTestModelResourcesTag), + std::move(model_file))); + return GetModelType(*model_resources); +} + +} // namespace + +class TextModelUtilsTest : public tflite_shims::testing::Test {}; + +TEST_F(TextModelUtilsTest, BertClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kBertClassifierPath)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, BertEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexClassifierPath)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexOneEmbeddingModel)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, StringInputModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kStringToBoolModelPath)); + ASSERT_EQ(model_type, TextModelType::STRING_MODEL); +} + +} // namespace mediapipe::tasks::text::utils From 24c8fa97e9aeb75ac6344957aff2a2d5b953061b Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Wed, 7 Dec 2022 19:04:31 -0800 Subject: [PATCH 122/346] Internal change PiperOrigin-RevId: 493768013 --- mediapipe/examples/ios/faceeffect/BUILD | 4 ++-- mediapipe/examples/ios/facemeshgpu/BUILD | 2 +- mediapipe/examples/ios/handtrackinggpu/BUILD | 2 +- mediapipe/examples/ios/iristrackinggpu/BUILD | 2 +- mediapipe/examples/ios/posetrackinggpu/BUILD | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index e0c3abb86..7d3a75cc6 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -74,10 +74,12 @@ objc_library( ], features = ["-layering_check"], deps = [ + "//mediapipe/framework/formats:matrix_data_cc_proto", "//third_party/apple_frameworks:AVFoundation", "//third_party/apple_frameworks:CoreGraphics", "//third_party/apple_frameworks:CoreMedia", "//third_party/apple_frameworks:UIKit", + "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", @@ -85,9 +87,7 @@ objc_library( "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/graphs/face_effect:face_effect_gpu_deps", - 
"//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 02103ce2f..6caf8c09c 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/face_mesh:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 647b7670a..c5b8e7b58 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/hand_tracking:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 056447d63..646d2e5a2 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 86b41ed36..4fbc2280c 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) From 9ae2e43b70188cd73fd478364b71d32410f9c21c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 19:17:14 -0800 Subject: [PATCH 123/346] Open Source the remaining MediaPipe Tasks tests for Web PiperOrigin-RevId: 493769657 --- .../audio_classifier_graph_options.proto | 1 + .../proto/audio_embedder_graph_options.proto | 1 + .../proto/text_classifier_graph_options.proto | 1 + .../proto/text_embedder_graph_options.proto | 1 + .../gesture_classifier_graph_options.proto | 1 + .../gesture_embedder_graph_options.proto | 1 + .../gesture_recognizer_graph_options.proto | 1 + ...and_gesture_recognizer_graph_options.proto | 1 + .../proto/hand_detector_graph_options.proto | 1 + .../proto/hand_landmarker_graph_options.proto | 1 + ...and_landmarks_detector_graph_options.proto | 1 + .../image_classifier_graph_options.proto | 1 + .../proto/image_embedder_graph_options.proto | 1 + .../proto/image_segmenter_graph_options.proto | 1 + .../proto/object_detector_options.proto | 1 + 
.../tasks/web/audio/audio_classifier/BUILD | 21 ++ .../audio_classifier/audio_classifier_test.ts | 208 ++++++++++++ .../tasks/web/audio/audio_embedder/BUILD | 21 ++ .../audio_embedder/audio_embedder_test.ts | 185 +++++++++++ .../tasks/web/text/text_classifier/BUILD | 22 ++ .../text_classifier/text_classifier_test.ts | 152 +++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 21 ++ .../text/text_embedder/text_embedder_test.ts | 165 ++++++++++ mediapipe/tasks/web/vision/core/BUILD | 18 + .../vision/core/vision_task_runner.test.ts | 99 ++++++ .../tasks/web/vision/gesture_recognizer/BUILD | 25 ++ .../gesture_recognizer_test.ts | 307 ++++++++++++++++++ .../tasks/web/vision/hand_landmarker/BUILD | 25 ++ .../hand_landmarker/hand_landmarker_test.ts | 251 ++++++++++++++ .../tasks/web/vision/image_classifier/BUILD | 24 ++ .../image_classifier/image_classifier_test.ts | 150 +++++++++ .../tasks/web/vision/image_embedder/BUILD | 21 ++ .../image_embedder/image_embedder_test.ts | 158 +++++++++ .../tasks/web/vision/object_detector/BUILD | 24 ++ .../object_detector/object_detector_test.ts | 229 +++++++++++++ 35 files changed, 2141 insertions(+) create mode 100644 mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts create mode 100644 mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts create mode 100644 mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts create mode 100644 mediapipe/tasks/web/vision/core/vision_task_runner.test.ts create mode 100644 mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts create mode 100644 mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts create mode 100644 mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts create mode 100644 mediapipe/tasks/web/vision/object_detector/object_detector_test.ts diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 5d4ba3296..cc26b3070 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto index 25c5d5474..367a1bf26 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto 
b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index 8f4d7eea6..41f87b519 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index e7e3a63c7..fc8e02858 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index dcefa075f..edbabc018 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index bff4e0a9c..df909a6db 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto index 57d8a3746..fef22c07c 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import 
"mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index 7df2fed37..ae85509da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index a009f2365..bede70da5 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 51e4e129a..d0edf99c0 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 195f6e5cc..a2d520963 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto 
b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 76315e230..24b126a35 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 72b3e7ee3..24ee866f2 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 4d8100842..5c7d2ec71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index cba58ace8..3f6932f8f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.object_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto"; diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index dc82a4a24..24ef31feb 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,6 +2,7 @@ # # This task takes audio data and outputs the classification result. 
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -44,3 +45,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:classifier_options", ], ) + +mediapipe_ts_library( + name = "audio_classifier_test_lib", + testonly = True, + srcs = [ + "audio_classifier_test.ts", + ], + deps = [ + ":audio_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_classifier_test", + deps = [":audio_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts new file mode 100644 index 000000000..d5c0a9429 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -0,0 +1,208 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioClassifier} from './audio_classifier'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class AudioClassifierFake extends AudioClassifier implements + MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = + 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + private resultProtoVector: ClassificationResult[] = []; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_classifications'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'addDoubleToStream') + .and.callFake((sampleRate, streamName, timestamp) => { + if (streamName === 'sample_rate') { + this.lastSampleRate = sampleRate; + } + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape') + .and.callFake( + (audioData, numChannels, numSamples, streamName, timestamp) => { + expect(numChannels).toBe(1); + }); + spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => { + if (!this.protoVectorListener) return; + this.protoVectorListener(this.resultProtoVector.map( + classificationResult => classificationResult.serializeBinary())); + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } + + /** Sets the Protobuf that will be send to the API. */ + setResults(results: ClassificationResult[]): void { + this.resultProtoVector = results; + } +} + +describe('AudioClassifier', () => { + let audioClassifier: AudioClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioClassifier = new AudioClassifierFake(); + await audioClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(audioClassifier); + verifyListenersRegistered(audioClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await audioClassifier.setOptions({maxResults: 1}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(audioClassifier); + + await audioClassifier.setOptions({maxResults: 5}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(audioClassifier); + }); + + it('merges options', async () => { + await audioClassifier.setOptions({maxResults: 1}); + await audioClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(audioClassifier, [ + 'classifierOptions', { + maxResults: 1, + displayNamesLocale: 'en', + scoreThreshold: undefined, + categoryAllowlistList: [], + categoryDenylistList: [] + } + ]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([]), 44100); + expect(audioClassifier.lastSampleRate).toEqual(44100); + }); + + it('transforms results', async () => { + const resultProtoVector: 
        ClassificationResult[] = [];
+
+    let classificationResult = new ClassificationResult();
+    classificationResult.setTimestampMs(0);
+    let classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    let classificationList = new ClassificationList();
+    let classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+    resultProtoVector.push(classificationResult);
+
+    classificationResult = new ClassificationResult();
+    classificationResult.setTimestampMs(1);
+    classifications = new Classifications();
+    classificationList = new ClassificationList();
+    classification = new Classification();
+    classification.setIndex(2);
+    classification.setScore(0.3);
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+    resultProtoVector.push(classificationResult);
+
+    // Invoke the audio classifier
+    audioClassifier.setResults(resultProtoVector);
+    const results = audioClassifier.classify(new Float32Array([]));
+    expect(results.length).toEqual(2);
+    expect(results[0]).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }],
+      timestampMs: 0
+    });
+    expect(results[1]).toEqual({
+      classifications: [{
+        categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}],
+        headIndex: 0,
+        headName: ''
+      }],
+      timestampMs: 1
+    });
+  });
+
+  it('clears results between invocations', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    audioClassifier.setResults([classificationResult]);
+
+    // Invoke the audio classifier twice
+    const classifications1 = audioClassifier.classify(new Float32Array([]));
+    const classifications2 = audioClassifier.classify(new Float32Array([]));
+
+    // Verify that classifications2 is not a concatenation of all previously
+    // returned classifications.
+    expect(classifications1).toEqual(classifications2);
+  });
+});
diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD
index dc84d0cd6..0817776c5 100644
--- a/mediapipe/tasks/web/audio/audio_embedder/BUILD
+++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD
@@ -3,6 +3,7 @@
 # This task takes audio input and performs embedding.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -43,3 +44,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:embedder_options", ], ) + +mediapipe_ts_library( + name = "audio_embedder_test_lib", + testonly = True, + srcs = [ + "audio_embedder_test.ts", + ], + deps = [ + ":audio_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_embedder_test", + deps = [":audio_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts new file mode 100644 index 000000000..2f605ff98 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -0,0 +1,185 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder'; + + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + this.attachListenerSpies[1] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_embeddings_out'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => { + this.lastSampleRate = sampleRate; + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape'); + } +} + +describe('AudioEmbedder', () => { + let audioEmbedder: AudioEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioEmbedder = new AudioEmbedderFake(); + await audioEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', () => { + verifyGraph(audioEmbedder); + verifyListenersRegistered(audioEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await audioEmbedder.setOptions({quantize: true}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(audioEmbedder); + + await audioEmbedder.setOptions({quantize: undefined}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(audioEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await audioEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + audioEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await audioEmbedder.setOptions({quantize: true}); + await audioEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + audioEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([]), 44100); + 
    expect(audioEmbedder.lastSampleRate).toEqual(44100);
+  });
+
+  describe('transforms results', () => {
+    const embedding = new Embedding();
+    embedding.setHeadIndex(1);
+    embedding.setHeadName('headName');
+
+    const floatEmbedding = new FloatEmbedding();
+    floatEmbedding.setValuesList([0.1, 0.9]);
+
+    embedding.setFloatEmbedding(floatEmbedding);
+    const resultProto = new EmbeddingResultProto();
+    resultProto.addEmbeddings(embedding);
+
+    function validateEmbeddingResult(
+        expectedEmbeddingResult: AudioEmbedderResult[]) {
+      expect(expectedEmbeddingResult.length).toEqual(1);
+
+      const [embeddingResult] = expectedEmbeddingResult;
+      expect(embeddingResult.embeddings.length).toEqual(1);
+      expect(embeddingResult.embeddings[0])
+          .toEqual(
+              {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'});
+    }
+
+    it('from embeddings stream', async () => {
+      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+        verifyListenersRegistered(audioEmbedder);
+        // Pass the test data to our listener
+        audioEmbedder.protoListener!(resultProto.serializeBinary());
+      });
+
+      // Invoke the audio embedder
+      const embeddingResults = audioEmbedder.embed(new Float32Array([]));
+      validateEmbeddingResult(embeddingResults);
+    });
+
+    it('from timestamped embeddings stream', async () => {
+      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+        verifyListenersRegistered(audioEmbedder);
+        // Pass the test data to our listener
+        audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]);
+      });
+
+      // Invoke the audio embedder
+      const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42);
+      validateEmbeddingResult(embeddingResults);
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD
index 07f78ac20..fd97c3db4 100644
--- a/mediapipe/tasks/web/text/text_classifier/BUILD
+++ b/mediapipe/tasks/web/text/text_classifier/BUILD
@@ -4,6 +4,7 @@
 # BERT-based text classification).

 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

 package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -45,3 +46,24 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:classifier_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_classifier_test_lib",
+    testonly = True,
+    srcs = [
+        "text_classifier_test.ts",
+    ],
+    deps = [
+        ":text_classifier",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_classifier_test",
+    deps = [":text_classifier_test_lib"],
+)
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
new file mode 100644
index 000000000..841bf8c48
--- /dev/null
+++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextClassifier} from './text_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextClassifierFake extends TextClassifier implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextClassifier', () => { + let textClassifier: TextClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textClassifier = new TextClassifierFake(); + await textClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textClassifier); + verifyListenersRegistered(textClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await textClassifier.setOptions({maxResults: 1}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(textClassifier); + + await textClassifier.setOptions({maxResults: 5}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(textClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await textClassifier.setOptions({maxResults: 1}); + await textClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(textClassifier, [ + 
'classifierOptions', {
+        maxResults: 1,
+        displayNamesLocale: 'en',
+        scoreThreshold: undefined,
+        categoryAllowlistList: [],
+        categoryDenylistList: []
+      }
+    ]);
+  });
+
+  it('transforms results', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    // Pass the test data to our listener
+    textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(textClassifier);
+      textClassifier.protoListener!(classificationResult.serializeBinary());
+    });
+
+    // Invoke the text classifier
+    const result = textClassifier.classify('foo');
+
+    expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+    expect(result).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }]
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD
index 7d796fb7e..1514944bf 100644
--- a/mediapipe/tasks/web/text/text_embedder/BUILD
+++ b/mediapipe/tasks/web/text/text_embedder/BUILD
@@ -4,6 +4,7 @@
 #
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -44,3 +45,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:embedder_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "text_embedder_test.ts",
+    ],
+    deps = [
+        ":text_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_embedder_test",
+    deps = [":text_embedder_test_lib"],
+)
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
new file mode 100644
index 000000000..04a9b371a
--- /dev/null
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
@@ -0,0 +1,165 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextEmbedder} from './text_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextEmbedder', () => { + let textEmbedder: TextEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textEmbedder = new TextEmbedderFake(); + await textEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textEmbedder); + verifyListenersRegistered(textEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await textEmbedder.setOptions({quantize: true}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(textEmbedder); + + await textEmbedder.setOptions({quantize: undefined}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(textEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await textEmbedder.setOptions({quantize: true}); + await textEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + textEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('transforms results', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + 
verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + }); + + it('transforms custom quantized values', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingsResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingsResult.embeddings.length).toEqual(1); + expect(embeddingsResult.embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index b389a9b01..e4ea3036f 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Vision Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -22,3 +23,20 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_library( + name = "vision_task_runner_test_lib", + testonly = True, + srcs = ["vision_task_runner.test.ts"], + deps = [ + ":vision_task_runner", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "vision_task_runner_test", + deps = [":vision_task_runner_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts new file mode 100644 index 000000000..6cc9ea328 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -0,0 +1,99 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +import {VisionTaskRunner} from './vision_task_runner'; + +class VisionTaskRunnerFake extends VisionTaskRunner { + baseOptions = new BaseOptionsProto(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + } + + protected override process(): void {} + + override processImageData(image: ImageSource): void { + super.processImageData(image); + } + + override processVideoData(imageFrame: ImageSource, timestamp: number): void { + super.processVideoData(imageFrame, timestamp); + } +} + +describe('VisionTaskRunner', () => { + const streamMode = { + modelAsset: undefined, + useStreamMode: true, + acceleration: undefined, + }; + + const imageMode = { + modelAsset: undefined, + useStreamMode: false, + acceleration: undefined, + }; + + let visionTaskRunner: VisionTaskRunnerFake; + + beforeEach(() => { + visionTaskRunner = new VisionTaskRunnerFake(); + }); + + it('can enable image mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('can enable video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + }); + + it('can clear running mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + + // Clear running mode + await visionTaskRunner.setOptions({runningMode: undefined}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('cannot process images with video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(() => { + visionTaskRunner.processImageData({} as HTMLImageElement); + }).toThrowError(/Task is not initialized with image mode./); + }); + + it('cannot process video with image mode', async () => { + // Use default for `useStreamMode` + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + + // Explicitly set to image mode + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + }); +}); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 6e2e56196..aa2f9c366 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more gesture categories, using Gesture Recognizer. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -52,3 +53,27 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "gesture_recognizer_test_lib", + testonly = True, + srcs = [ + "gesture_recognizer_test.ts", + ], + deps = [ + ":gesture_recognizer", + ":gesture_recognizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "gesture_recognizer_test", + tags = ["nomsan"], + deps = [":gesture_recognizer_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts new file mode 100644 index 000000000..c0f0d1554 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -0,0 +1,307 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern
+
+type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
+
+function createHandednesses(): Uint8Array[] {
+  const handsProto = new ClassificationList();
+  const classification = new Classification();
+  classification.setScore(0.1);
+  classification.setIndex(1);
+  classification.setLabel('handedness_label');
+  classification.setDisplayName('handedness_display_name');
+  handsProto.addClassification(classification);
+  return [handsProto.serializeBinary()];
+}
+
+function createGestures(): Uint8Array[] {
+  const gesturesProto = new ClassificationList();
+  const classification = new Classification();
+  classification.setScore(0.2);
+  classification.setIndex(2);
+  classification.setLabel('gesture_label');
+  classification.setDisplayName('gesture_display_name');
+  gesturesProto.addClassification(classification);
+  return [gesturesProto.serializeBinary()];
+}
+
+function createLandmarks(): Uint8Array[] {
+  const handLandmarksProto = new NormalizedLandmarkList();
+  const landmark = new NormalizedLandmark();
+  landmark.setX(0.3);
+  landmark.setY(0.4);
+  landmark.setZ(0.5);
+  handLandmarksProto.addLandmark(landmark);
+  return [handLandmarksProto.serializeBinary()];
+}
+
+function createWorldLandmarks(): Uint8Array[] {
+  const handLandmarksProto = new LandmarkList();
+  const landmark = new Landmark();
+  landmark.setX(21);
+  landmark.setY(22);
+  landmark.setZ(23);
+  handLandmarksProto.addLandmark(landmark);
+  return [handLandmarksProto.serializeBinary()];
+}
+
+class GestureRecognizerFake extends GestureRecognizer implements
+    MediapipeTasksFake {
+  calculatorName =
+      'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+  fakeWasmModule: SpyWasmModule;
+  listeners = new Map<string, ProtoListener>();
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachProtoVectorListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toMatch(
+                  /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/);
+              this.listeners.set(stream, listener);
+            });
+
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+    spyOn(this.graphRunner, 'addProtoToStream');
+  }
+
+  getGraphRunner(): GraphRunnerImageLib {
+    return this.graphRunner;
+  }
+}
+
+describe('GestureRecognizer', () => {
+  let gestureRecognizer: GestureRecognizerFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    gestureRecognizer = new GestureRecognizerFake();
+    await gestureRecognizer.setOptions({});  // Initialize graph
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(gestureRecognizer);
+    verifyListenersRegistered(gestureRecognizer);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    await gestureRecognizer.setOptions({numHands: 1});
+    verifyGraph(gestureRecognizer, [
+      ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
+    ]);
+    verifyListenersRegistered(gestureRecognizer);
+
+    await gestureRecognizer.setOptions({numHands: 5});
+    verifyGraph(gestureRecognizer, [
+      ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5
+    ]);
+    verifyListenersRegistered(gestureRecognizer);
+  });
+
+  it('merges options', async () =>
{
+    await gestureRecognizer.setOptions({numHands: 1});
+    await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5});
+    verifyGraph(gestureRecognizer, [
+      ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
+    ]);
+    verifyGraph(gestureRecognizer, [
+      [
+        'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
+        'minDetectionConfidence'
+      ],
+      0.5
+    ]);
+  });
+
+  describe('setOptions()', () => {
+    interface TestCase {
+      optionPath: [keyof GestureRecognizerOptions, ...string[]];
+      fieldPath: string[];
+      customValue: unknown;
+      defaultValue: unknown;
+    }
+
+    const testCases: TestCase[] = [
+      {
+        optionPath: ['numHands'],
+        fieldPath: [
+          'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'
+        ],
+        customValue: 5,
+        defaultValue: 1
+      },
+      {
+        optionPath: ['minHandDetectionConfidence'],
+        fieldPath: [
+          'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
+          'minDetectionConfidence'
+        ],
+        customValue: 0.1,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['minHandPresenceConfidence'],
+        fieldPath: [
+          'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions',
+          'minDetectionConfidence'
+        ],
+        customValue: 0.2,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['minTrackingConfidence'],
+        fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'],
+        customValue: 0.3,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'],
+        fieldPath: [
+          'handGestureRecognizerGraphOptions',
+          'cannedGestureClassifierGraphOptions', 'classifierOptions',
+          'scoreThreshold'
+        ],
+        customValue: 0.4,
+        defaultValue: undefined
+      },
+      {
+        optionPath: ['customGesturesClassifierOptions', 'scoreThreshold'],
+        fieldPath: [
+          'handGestureRecognizerGraphOptions',
+          'customGestureClassifierGraphOptions', 'classifierOptions',
+          'scoreThreshold'
+        ],
+        customValue: 0.5,
+        defaultValue: undefined,
+      },
+    ];
+
+    /** Creates an options object that can be passed to setOptions() */
+    function createOptions(
+        path: string[], value: unknown): GestureRecognizerOptions {
+      const options: Record<string, unknown> = {};
+      let currentLevel = options;
+      for (const element of path.slice(0, -1)) {
+        currentLevel[element] = {};
+        currentLevel = currentLevel[element] as Record<string, unknown>;
+      }
+      currentLevel[path[path.length - 1]] = value;
+      return options;
+    }
+
+    for (const testCase of testCases) {
+      it(`uses default value for ${testCase.optionPath[0]}`, async () => {
+        verifyGraph(
+            gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
+      });
+
+      it(`can set ${testCase.optionPath[0]}`, async () => {
+        await gestureRecognizer.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(
+            gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
+      });
+
+      it(`can clear ${testCase.optionPath[0]}`, async () => {
+        await gestureRecognizer.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(
+            gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
+
+        await gestureRecognizer.setOptions(
+            createOptions(testCase.optionPath, undefined));
+        verifyGraph(
+            gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
+      });
+    }
+  });
+
+  it('transforms results', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(gestureRecognizer);
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
+    });
+
+    // Invoke the gesture recognizer
+    const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
+    expect(gestureRecognizer.getGraphRunner().addProtoToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+
+    expect(gestures).toEqual({
+      'gestures': [[{
+        'score': 0.2,
+        'index': 2,
+        'categoryName': 'gesture_label',
+        'displayName': 'gesture_display_name'
+      }]],
+      'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
+      'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
+      'handednesses': [[{
+        'score': 0.1,
+        'index': 1,
+        'categoryName': 'handedness_label',
+        'displayName': 'handedness_display_name'
+      }]]
+    });
+  });
+
+  it('clears results between invocations', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
+    });
+
+    // Invoke the gesture recognizer twice
+    const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement);
+    const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement);
+
+    // Verify that gestures2 is not a concatenation of all previously returned
+    // gestures.
+    expect(gestures2).toEqual(gestures1);
+  });
+});
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
index 520898e34..d1f1e48f3 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD
+++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
@@ -4,6 +4,7 @@
 #   the detection results for one or more hand categories, using Hand Landmarker.
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -47,3 +48,27 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/vision/core:vision_task_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "hand_landmarker_test_lib",
+    testonly = True,
+    srcs = [
+        "hand_landmarker_test.ts",
+    ],
+    deps = [
+        ":hand_landmarker",
+        ":hand_landmarker_types",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/framework/formats:landmark_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "hand_landmarker_test",
+    tags = ["nomsan"],
+    deps = [":hand_landmarker_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
new file mode 100644
index 000000000..fc26680e0
--- /dev/null
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
@@ -0,0 +1,251 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import 'jasmine';
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
+import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+
+import {HandLandmarker} from './hand_landmarker';
+import {HandLandmarkerOptions} from './hand_landmarker_options';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
+
+function createHandednesses(): Uint8Array[] {
+  const handsProto = new ClassificationList();
+  const classification = new Classification();
+  classification.setScore(0.1);
+  classification.setIndex(1);
+  classification.setLabel('handedness_label');
+  classification.setDisplayName('handedness_display_name');
+  handsProto.addClassification(classification);
+  return [handsProto.serializeBinary()];
+}
+
+function createLandmarks(): Uint8Array[] {
+  const handLandmarksProto = new NormalizedLandmarkList();
+  const landmark = new NormalizedLandmark();
+  landmark.setX(0.3);
+  landmark.setY(0.4);
+  landmark.setZ(0.5);
+  handLandmarksProto.addLandmark(landmark);
+  return [handLandmarksProto.serializeBinary()];
+}
+
+function createWorldLandmarks(): Uint8Array[] {
+  const handLandmarksProto = new LandmarkList();
+  const landmark = new Landmark();
+  landmark.setX(21);
+  landmark.setY(22);
+  landmark.setZ(23);
+  handLandmarksProto.addLandmark(landmark);
+  return [handLandmarksProto.serializeBinary()];
+}
+
+class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
+  calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+  fakeWasmModule: SpyWasmModule;
+  listeners = new Map<string, ProtoListener>();
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachProtoVectorListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toMatch(
+                  /(hand_landmarks|world_hand_landmarks|handedness)/);
+              this.listeners.set(stream, listener);
+            });
+
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+    spyOn(this.graphRunner, 'addProtoToStream');
+  }
+
+  getGraphRunner(): GraphRunnerImageLib {
+    return this.graphRunner;
+  }
+}
+
+describe('HandLandmarker', () => {
+  let handLandmarker: HandLandmarkerFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    handLandmarker = new HandLandmarkerFake();
+    await handLandmarker.setOptions({});  // Initialize graph
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(handLandmarker);
+    verifyListenersRegistered(handLandmarker);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    verifyListenersRegistered(handLandmarker);
+
+    await handLandmarker.setOptions({numHands: 1});
+    verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]);
+    verifyListenersRegistered(handLandmarker);
+
+    await handLandmarker.setOptions({numHands: 5});
+    verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]);
+    verifyListenersRegistered(handLandmarker);
+  });
+
+  it('merges options', async () => {
+    await handLandmarker.setOptions({numHands: 1});
+    await handLandmarker.setOptions({minHandDetectionConfidence: 0.5});
+    verifyGraph(handLandmarker, [
+      'handDetectorGraphOptions',
+      {numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5}
+    ]);
+  });
+
+  describe('setOptions()', () => {
+    interface TestCase {
+      optionPath: [keyof HandLandmarkerOptions, ...string[]];
+      fieldPath: string[];
+      customValue: unknown;
+      defaultValue: unknown;
+    }
+
+    const testCases: TestCase[] = [
+      {
+        optionPath: ['numHands'],
+        fieldPath: ['handDetectorGraphOptions', 'numHands'],
+        customValue: 5,
+        defaultValue: 1
+      },
+      {
+        optionPath: ['minHandDetectionConfidence'],
+        fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'],
+        customValue: 0.1,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['minHandPresenceConfidence'],
+        fieldPath:
+            ['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
+        customValue: 0.2,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['minTrackingConfidence'],
+        fieldPath: ['minTrackingConfidence'],
+        customValue: 0.3,
+        defaultValue: 0.5
+      },
+    ];
+
+    /** Creates an options object that can be passed to setOptions() */
+    function createOptions(
+        path: string[], value: unknown): HandLandmarkerOptions {
+      const options: Record<string, unknown> = {};
+      let currentLevel = options;
+      for (const element of path.slice(0, -1)) {
+        currentLevel[element] = {};
+        currentLevel = currentLevel[element] as Record<string, unknown>;
+      }
+      currentLevel[path[path.length - 1]] = value;
+      return options;
+    }
+
+    for (const testCase of testCases) {
+      it(`uses default value for ${testCase.optionPath[0]}`, async () => {
+        verifyGraph(
+            handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
+      });
+
+      it(`can set ${testCase.optionPath[0]}`, async () => {
+        await handLandmarker.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
+      });
+
+      it(`can clear ${testCase.optionPath[0]}`, async () => {
+        await handLandmarker.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
+
+        await handLandmarker.setOptions(
+            createOptions(testCase.optionPath, undefined));
+        verifyGraph(
+            handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
+      });
+    }
+  });
+
+  it('transforms results', async () => {
+    // Pass the test data to our listener
+    handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(handLandmarker);
+      handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
+      handLandmarker.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      handLandmarker.listeners.get('handedness')!(createHandednesses());
+    });
+
+    // Invoke the hand landmarker
+    const landmarks = handLandmarker.detect({} as HTMLImageElement);
+    expect(handLandmarker.getGraphRunner().addProtoToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+
+    expect(landmarks).toEqual({
+      'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
+      'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
+      'handednesses': [[{
+        'score': 0.1,
+        'index': 1,
+        'categoryName': 'handedness_label',
+        'displayName': 'handedness_display_name'
+      }]]
+    });
+  });
+
+  it('clears results between invocations', async () => {
+    // Pass the test data to our listener
+    handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
+      handLandmarker.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      handLandmarker.listeners.get('handedness')!(createHandednesses());
+    });
+
+    // Invoke the hand landmarker twice
+    const landmarks1 = handLandmarker.detect({} as HTMLImageElement);
+    const landmarks2 = handLandmarker.detect({} as HTMLImageElement);
+
+    // Verify that landmarks2 is not a concatenation of all previously returned
+    // landmarks.
+    expect(landmarks1).toEqual(landmarks2);
+  });
+});
diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD
index 848c162ae..310575964 100644
--- a/mediapipe/tasks/web/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/web/vision/image_classifier/BUILD
@@ -3,6 +3,7 @@
 # This task takes video or image frames and outputs the classification result.
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -44,3 +45,26 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/vision/core:vision_task_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "image_classifier_test_lib",
+    testonly = True,
+    srcs = [
+        "image_classifier_test.ts",
+    ],
+    deps = [
+        ":image_classifier",
+        ":image_classifier_types",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "image_classifier_test",
+    tags = ["nomsan"],
+    deps = [":image_classifier_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
new file mode 100644
index 000000000..2041a0cef
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
@@ -0,0 +1,150 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageClassifier} from './image_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageClassifierFake extends ImageClassifier implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageClassifier', () => { + let imageClassifier: ImageClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageClassifier = new ImageClassifierFake(); + await imageClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(imageClassifier); + verifyListenersRegistered(imageClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await imageClassifier.setOptions({maxResults: 1}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(imageClassifier); + + await imageClassifier.setOptions({maxResults: 5}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(imageClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageClassifier.setOptions({maxResults: 1}); + await 
imageClassifier.setOptions({displayNamesLocale: 'en'});
+    verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
+    verifyGraph(
+        imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']);
+  });
+
+  it('transforms results', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    // Pass the test data to our listener
+    imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(imageClassifier);
+      imageClassifier.protoListener!(classificationResult.serializeBinary());
+    });
+
+    // Invoke the image classifier
+    const result = imageClassifier.classify({} as HTMLImageElement);
+
+    expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+    expect(result).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }]
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index 6c9d80fb1..de4785e6c 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -3,6 +3,7 @@
 # This task performs embedding extraction on images.
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -45,3 +46,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/vision/core:vision_task_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "image_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "image_embedder_test.ts",
+    ],
+    deps = [
+        ":image_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "image_embedder_test",
+    deps = [":image_embedder_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
new file mode 100644
index 000000000..cafe0f3d8
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageEmbedder} from './image_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageEmbedder', () => { + let imageEmbedder: ImageEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageEmbedder = new ImageEmbedderFake(); + await imageEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(imageEmbedder); + verifyListenersRegistered(imageEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: true}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: undefined}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(imageEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('overrides options', async () => { + await imageEmbedder.setOptions({quantize: true}); + await imageEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + imageEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + describe('transforms result', () => { + beforeEach(() => { + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + 
embedding.setFloatEmbedding(floatEmbedding); + + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(42); + + // Pass the test data to our listener + imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageEmbedder); + imageEmbedder.protoListener!(resultProto.serializeBinary()); + }); + }); + + it('for image mode', async () => { + // Invoke the image embedder + const embeddingResult = imageEmbedder.embed({} as HTMLImageElement); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + + it('for video mode', async () => { + await imageEmbedder.setOptions({runningMode: 'video'}); + + // Invoke the video embedder + const embeddingResult = + imageEmbedder.embedForVideo({} as HTMLImageElement, 42); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index f73790895..fc206a2d7 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more object categories, using Object Detector. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -41,3 +42,26 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "object_detector_test_lib", + testonly = True, + srcs = [ + "object_detector_test.ts", + ], + deps = [ + ":object_detector", + ":object_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "object_detector_test", + tags = ["nomsan"], + deps = [":object_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts new file mode 100644 index 000000000..fff1a1c48 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -0,0 +1,229 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ObjectDetector} from './object_detector'; +import {ObjectDetectorOptions} from './object_detector_options'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ObjectDetector', () => { + let objectDetector: ObjectDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + objectDetector = new ObjectDetectorFake(); + await objectDetector.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(objectDetector); + verifyListenersRegistered(objectDetector); + }); + + it('reloads graph when settings are changed', async () => { + await objectDetector.setOptions({maxResults: 1}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyListenersRegistered(objectDetector); + + await objectDetector.setOptions({maxResults: 5}); + verifyGraph(objectDetector, ['maxResults', 5]); + verifyListenersRegistered(objectDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await objectDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + objectDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await objectDetector.setOptions({maxResults: 1}); + await objectDetector.setOptions({displayNamesLocale: 'en'}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyGraph(objectDetector, ['displayNamesLocale', 'en']); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionName: keyof ObjectDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, 
+ { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + await objectDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + objectDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('transforms results', async () => { + const detectionProtos: Uint8Array[] = []; + + // Add a detection with all optional properties + let detection = new DetectionProto(); + detection.addScore(0.1); + detection.addLabelId(1); + detection.addLabel('foo'); + detection.addDisplayName('bar'); + let locationData = new LocationData(); + let boundingBox = new LocationData.BoundingBox(); + boundingBox.setXmin(1); + boundingBox.setYmin(2); + boundingBox.setWidth(3); + boundingBox.setHeight(4); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Add a detection without optional properties + detection = new DetectionProto(); + detection.addScore(0.2); + locationData = new LocationData(); + boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Pass the test data to our listener + objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(objectDetector); + objectDetector.protoListener!(detectionProtos); + }); + + // Invoke the object detector + const detections = objectDetector.detect({} as HTMLImageElement); + + expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(2); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: 1, + categoryName: 'foo', + displayName: 'bar', + }], + boundingBox: {originX: 1, originY: 2, width: 3, height: 4} + }); + expect(detections[1]).toEqual({ + categories: [{ + score: 0.2, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); From d1820320b15893a0f0b947ed208bdcfb630bb938 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 8 Dec 2022 10:23:53 +0530 Subject: [PATCH 124/346] Added base options --- mediapipe/tasks/ios/core/BUILD | 33 ++++++++++++ .../tasks/ios/core/sources/MPPBaseOptions.h | 51 +++++++++++++++++++ .../tasks/ios/core/sources/MPPBaseOptions.m | 36 +++++++++++++ .../tasks/ios/core/sources/MPPExternalFile.h | 28 ++++++++++ .../tasks/ios/core/sources/MPPExternalFile.m | 27 ++++++++++ 5 files changed, 175 insertions(+) create mode 100644 mediapipe/tasks/ios/core/BUILD create mode 100644 
mediapipe/tasks/ios/core/sources/MPPBaseOptions.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPBaseOptions.m
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.h
 create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.m

diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD
new file mode 100644
index 000000000..9b8ad7bec
--- /dev/null
+++ b/mediapipe/tasks/ios/core/BUILD
@@ -0,0 +1,33 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPExternalFile",
+    srcs = ["sources/MPPExternalFile.m"],
+    hdrs = ["sources/MPPExternalFile.h"],
+)
+
+objc_library(
+    name = "MPPBaseOptions",
+    srcs = ["sources/MPPBaseOptions.m"],
+    hdrs = ["sources/MPPBaseOptions.h"],
+    deps = [
+        ":MPPExternalFile",
+
+    ],
+)
diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h
new file mode 100644
index 000000000..87b6826df
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h
@@ -0,0 +1,51 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+     http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * MediaPipe Tasks delegate.
+ */
+typedef NS_ENUM(NSUInteger, MPPDelegate) {
+
+  /** CPU. */
+  MPPDelegateCPU,
+
+  /** GPU. */
+  MPPDelegateGPU
+} NS_SWIFT_NAME(Delegate);
+
+/**
+ * Holds the base options that are used for the creation of any type of task. It has fields
+ * with important information such as the acceleration configuration and the TFLite model source.
+ */
+NS_SWIFT_NAME(BaseOptions)
+@interface MPPBaseOptions : NSObject <NSCopying>
+
+/**
+ * The external model file, as a single standalone TFLite file. It can be packed with TFLite Model
+ * Metadata[1] and associated files if they exist. Failing to provide the necessary metadata and
+ * associated files might result in errors.
+ */
+@property(nonatomic, copy) MPPExternalFile *modelAssetFile;
+
+/**
+ * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
+ * CPU delegate is used.
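+ *
+ * A minimal configuration sketch (illustrative only; the task-specific
+ * options classes that consume MPPBaseOptions are not part of this patch):
+ *
+ *   MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
+ *   baseOptions.modelAssetFile.filePath = @"model.tflite";
+ *   baseOptions.delegate = MPPDelegateGPU;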
+ */
+@property(nonatomic) MPPDelegate delegate;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m
new file mode 100644
index 000000000..4c25b80e8
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m
@@ -0,0 +1,36 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
+
+@implementation MPPBaseOptions
+
+- (instancetype)init {
+  self = [super init];
+  if (self) {
+    self.modelAssetFile = [[MPPExternalFile alloc] init];
+  }
+  return self;
+}
+
+- (id)copyWithZone:(NSZone *)zone {
+  MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
+
+  baseOptions.modelAssetFile = self.modelAssetFile;
+  baseOptions.delegate = self.delegate;
+
+  return baseOptions;
+}
+
+@end
diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h
new file mode 100644
index 000000000..a97802002
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h
@@ -0,0 +1,28 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+     http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Holds information about an external file.
+ */
+NS_SWIFT_NAME(ExternalFile)
+@interface MPPExternalFile : NSObject <NSCopying>
+
+/** Path to the file in the bundle. */
+@property(nonatomic, copy) NSString *filePath;
+/// Add provision for other sources in the future.
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.m b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m
new file mode 100644
index 000000000..70d85657c
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m
@@ -0,0 +1,27 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" + +@implementation MPPExternalFile + +- (id)copyWithZone:(NSZone *)zone { + MPPExternalFile *externalFile = [[MPPExternalFile alloc] init]; + + externalFile.filePath = self.filePath; + + return externalFile; +} + +@end From 66dbd9969a0aae7f71ce7096135a3b436ea76473 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 8 Dec 2022 10:25:01 +0530 Subject: [PATCH 125/346] Updated license text --- .../tasks/ios/core/sources/MPPBaseOptions.h | 25 +++++++++++-------- .../tasks/ios/core/sources/MPPExternalFile.h | 25 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 87b6826df..258b49b3b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import #import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h index a97802002..300fd4778 100644 --- a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import NS_ASSUME_NONNULL_BEGIN From 13f8fa51393a2883ec825ad717bfffb693d59376 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 8 Dec 2022 07:59:46 -0800 Subject: [PATCH 126/346] Retire the visibility group "//mediapipe/framework:mediapipe_internal" in the "mediapipe/calculators/tensor" dir. PiperOrigin-RevId: 493895834 --- mediapipe/calculators/tensor/BUILD | 76 ++++-------------------------- 1 file changed, 8 insertions(+), 68 deletions(-) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 577ac4111..dec68deac 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -24,7 +24,7 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) exports_files( glob(["testdata/image_to_tensor/*"]), @@ -44,9 +44,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "audio_to_tensor_calculator_proto", srcs = ["audio_to_tensor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -64,9 +61,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -113,9 +107,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_audio_calculator_proto", srcs = ["tensors_to_audio_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -125,9 +116,6 @@ mediapipe_proto_library( cc_library( name = "tensors_to_audio_calculator", srcs = ["tensors_to_audio_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":tensors_to_audio_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -164,9 +152,6 @@ cc_test( mediapipe_proto_library( name = "feedback_tensors_calculator_proto", srcs = ["feedback_tensors_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -184,9 +169,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -216,9 +198,6 @@ cc_test( mediapipe_proto_library( name = "bert_preprocessor_calculator_proto", srcs = ["bert_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -228,9 +207,6 @@ 
mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -274,9 +250,6 @@ cc_test( mediapipe_proto_library( name = "regex_preprocessor_calculator_proto", srcs = ["regex_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -286,9 +259,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -330,9 +300,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -366,9 +333,6 @@ cc_test( cc_library( name = "universal_sentence_encoder_preprocessor_calculator", srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -408,7 +372,6 @@ cc_test( mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -435,7 +398,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -460,7 +422,6 @@ cc_library( name = "inference_calculator_gl", srcs = ["inference_calculator_gl.cc"], tags = ["nomac"], # config problem with cpuinfo via TF - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_interface", @@ -478,7 +439,6 @@ cc_library( name = "inference_calculator_gl_advanced", srcs = ["inference_calculator_gl_advanced.cc"], tags = ["nomac"], - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", "@com_google_absl//absl/memory", @@ -509,7 +469,6 @@ cc_library( "-framework MetalKit", ], tags = ["ios"], - visibility = ["//visibility:public"], deps = [ "inference_calculator_interface", "//mediapipe/gpu:MPPMetalHelper", @@ -538,7 +497,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -558,7 +516,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -588,7 +545,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -635,7 +591,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -651,7 +606,6 @@ cc_library( cc_library( name = 
"inference_calculator_gl_if_compute_shader_available", - visibility = ["//visibility:public"], deps = selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [ @@ -667,7 +621,6 @@ cc_library( # inference_calculator_interface. cc_library( name = "inference_calculator", - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_cpu", @@ -681,7 +634,6 @@ cc_library( mediapipe_proto_library( name = "tensor_converter_calculator_proto", srcs = ["tensor_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -706,7 +658,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensor_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -725,6 +676,7 @@ cc_library( cc_library( name = "tensor_converter_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:android": [ "//mediapipe/gpu:gl_calculator_helper", @@ -769,7 +721,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_detections_calculator_proto", srcs = ["tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -794,7 +745,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -817,6 +767,7 @@ cc_library( cc_library( name = "tensors_to_detections_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", @@ -832,7 +783,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_landmarks_calculator_proto", srcs = ["tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -849,7 +799,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -864,7 +813,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_tensor_calculator_proto", srcs = ["landmarks_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -882,7 +830,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":landmarks_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -915,7 +862,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -932,7 +878,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -970,7 +915,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1001,7 +945,6 @@ 
cc_library( mediapipe_proto_library( name = "tensors_to_classification_calculator_proto", srcs = ["tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1039,7 +982,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", ":image_to_tensor_converter", @@ -1068,6 +1010,7 @@ cc_library( cc_library( name = "image_to_tensor_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = selects.with_or({ "//mediapipe:android": [ ":image_to_tensor_converter_gl_buffer", @@ -1091,7 +1034,6 @@ cc_library( mediapipe_proto_library( name = "image_to_tensor_calculator_proto", srcs = ["image_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1154,7 +1096,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_utils", "//mediapipe/framework/formats:image", @@ -1174,7 +1115,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_converter", ":image_to_tensor_utils", @@ -1194,6 +1134,7 @@ cc_library( name = "image_to_tensor_converter_gl_buffer", srcs = ["image_to_tensor_converter_gl_buffer.cc"], hdrs = ["image_to_tensor_converter_gl_buffer.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + selects.with_or({ "//mediapipe:apple": [], "//conditions:default": [ @@ -1227,6 +1168,7 @@ cc_library( name = "image_to_tensor_converter_gl_texture", srcs = ["image_to_tensor_converter_gl_texture.cc"], hdrs = ["image_to_tensor_converter_gl_texture.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1251,6 +1193,7 @@ cc_library( name = "image_to_tensor_converter_gl_utils", srcs = ["image_to_tensor_converter_gl_utils.cc"], hdrs = ["image_to_tensor_converter_gl_utils.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1280,6 +1223,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe:apple": [ ":image_to_tensor_converter", @@ -1311,7 +1255,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", "@com_google_absl//absl/status", @@ -1354,7 +1297,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "tensors_to_segmentation_calculator_proto", srcs = ["tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1372,7 +1314,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -1430,7 +1371,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", 
"//mediapipe/framework:calculator_framework", From a641ea12e15ec8e3ff552647ca569dc1ee9f59bc Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 8 Dec 2022 11:30:39 -0800 Subject: [PATCH 127/346] Update gesture recognizer to new mediapipe tasks pipeline PiperOrigin-RevId: 493950564 --- .../python/vision/gesture_recognizer/BUILD | 14 ++-- .../vision/gesture_recognizer/dataset.py | 62 ++++++++++------- .../vision/gesture_recognizer/dataset_test.py | 67 ++++++++----------- .../gesture_recognizer/metadata_writer.py | 62 +++++++++++++---- .../metadata_writer_test.py | 17 +++++ mediapipe/tasks/python/core/BUILD | 8 +-- mediapipe/tasks/python/vision/BUILD | 4 ++ 7 files changed, 147 insertions(+), 87 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 256447a8d..9123e36b0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -35,20 +35,21 @@ py_library( srcs = ["constants.py"], ) -# TODO: Change to py_library after migrating the MediaPipe hand solution -# library to MediaPipe hand task library. py_library( name = "dataset", srcs = ["dataset.py"], deps = [ ":constants", + ":metadata_writer", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/core/data:data_util", "//mediapipe/model_maker/python/core/utils:model_util", - "//mediapipe/python/solutions:hands", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "dataset_test", srcs = ["dataset_test.py"], @@ -56,10 +57,11 @@ py_test( ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], + tags = ["notsan"], deps = [ ":dataset", - "//mediapipe/python/solutions:hands", "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) @@ -131,6 +133,7 @@ py_library( ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "gesture_recognizer_test", size = "large", @@ -140,6 +143,7 @@ py_test( "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, + tags = ["notsan"], deps = [ ":gesture_recognizer_import", "//mediapipe/model_maker/python/core/utils:test_util", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 256f26fd6..6a2c878c0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -16,16 +16,22 @@ import dataclasses import os import random -from typing import List, NamedTuple, Optional +from typing import List, Optional -import cv2 import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.core.data import data_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.vision.gesture_recognizer import constants -from mediapipe.python.solutions import hands as mp_hands +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as 
base_options_module +from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module + +_Image = image_module.Image +_HandLandmarker = hand_landmarker_module.HandLandmarker +_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions +_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult @dataclasses.dataclass @@ -59,7 +65,7 @@ class HandData: handedness: List[float] -def _validate_data_sample(data: NamedTuple) -> bool: +def _validate_data_sample(data: _HandLandmarkerResult) -> bool: """Validates the input hand data sample. Args: @@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool: 'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness' or any of these attributes' values are none. Otherwise, True. """ - if (not hasattr(data, 'multi_hand_landmarks') or - data.multi_hand_landmarks is None): + if data.hand_landmarks is None or not data.hand_landmarks: return False - if (not hasattr(data, 'multi_hand_world_landmarks') or - data.multi_hand_world_landmarks is None): + if data.hand_world_landmarks is None or not data.hand_world_landmarks: return False - if not hasattr(data, 'multi_handedness') or data.multi_handedness is None: + if data.handedness is None or not data.handedness: return False return True def _get_hand_data(all_image_paths: List[str], - min_detection_confidence: float) -> Optional[HandData]: + min_detection_confidence: float) -> List[Optional[HandData]]: """Computes hand data (landmarks and handedness) in the input image. Args: @@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str], A HandData object. Returns None if no hand is detected. """ hand_data_result = [] - with mp_hands.Hands( - static_image_mode=True, - max_num_hands=1, - min_detection_confidence=min_detection_confidence) as hands: + hand_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + hand_landmarker_options = _HandLandmarkerOptions( + base_options=base_options_module.BaseOptions( + model_asset_buffer=hand_landmarker_writer.populate()), + num_hands=1, + min_hand_detection_confidence=min_detection_confidence, + min_hand_presence_confidence=0.5, + min_tracking_confidence=1, + ) + with _HandLandmarker.create_from_options( + hand_landmarker_options) as hand_landmarker: for path in all_image_paths: tf.compat.v1.logging.info('Loading image %s', path) - image = data_util.load_image(path) - # Flip image around y-axis for correct handedness output - image = cv2.flip(image, 1) - data = hands.process(image) + image = _Image.create_from_file(path) + data = hand_landmarker.detect(image) if not _validate_data_sample(data): hand_data_result.append(None) continue - hand_landmarks = [[ - hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_landmarks[0].landmark] + hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z] + for hand_landmark in data.hand_landmarks[0]] hand_world_landmarks = [[ hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark] + ] for hand_landmark in data.hand_world_landmarks[0]] handedness_scores = [ - handedness.score - for handedness in data.multi_handedness[0].classification + 
handedness.score for handedness in data.handedness[0] ] hand_data_result.append( HandData( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index 76e70a58d..528d02edd 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import os import shutil from typing import NamedTuple import unittest -from absl import flags from absl.testing import parameterized import tensorflow as tf from mediapipe.model_maker.python.vision.gesture_recognizer import dataset -from mediapipe.python.solutions import hands as mp_hands from mediapipe.tasks.python.test import test_utils - -FLAGS = flags.FLAGS +from mediapipe.tasks.python.vision import hand_landmarker _TEST_DATA_DIRNAME = 'raw_data' @@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) train_data, test_data = data.split(0.5) - self.assertLen(train_data, 17) + self.assertLen(train_data, 16) for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) self.assertEqual(train_data.num_classes, 4) self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock']) - self.assertLen(test_data, 18) + self.assertLen(test_data, 16) for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) @@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock']) @@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK']) @parameterized.named_parameters( dict( - testcase_name='invalid_field_name_multi_hand_landmark', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmark', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=None, hand_landmarks=[[2]], + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_hand_world_landmarks', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmark', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=None, + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_handed', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handed' - ])(1, 2, 3)), + testcase_name='none_hand_world_landmarks', + 
hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], + hand_world_landmarks=None)), dict( - testcase_name='multi_hand_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(None, 2, 3)), + testcase_name='empty_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_hand_world_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, None, 3)), + testcase_name='empty_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_handedness_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, None)), + testcase_name='empty_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])), ) def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): with unittest.mock.patch.object( - mp_hands.Hands, 'process', return_value=hand): + hand_landmarker.HandLandmarker, 'detect', return_value=hand): input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): dataset.Dataset.from_folder( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index 58b67e072..b2e851afe 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: return f.read() +class HandLandmarkerMetadataWriter: + """MetadataWriter to write the model asset bundle for HandLandmarker.""" + + def __init__( + self, + hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + ) -> None: + """Initializes HandLandmarkerMetadataWriter to write model asset bundle. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. + hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + """ + self._hand_detector_model_buffer = hand_detector_model_buffer + self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self): + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + def populate(self): + """Creates the model asset bundle for hand landmarker task. 
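+
+    The detector and landmarks-detector buffers are zipped together in a
+    temporary folder and the resulting bundle is returned as raw bytes, so
+    callers can pass it straight to BaseOptions(model_asset_buffer=...), as
+    dataset.py does above.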
+ + Returns: + Model asset bundle in bytes + """ + landmark_models = { + _HAND_DETECTOR_TFLITE_NAME: + self._hand_detector_model_buffer, + _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: + self._hand_landmarks_detector_model_buffer + } + output_hand_landmarker_path = os.path.join(self._temp_folder.name, + _HAND_LANDMARKER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(landmark_models, + output_hand_landmarker_path) + hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) + return hand_landmarker_model_buffer + + class MetadataWriter: """MetadataWriter to write the metadata and the model asset bundle.""" @@ -86,8 +130,8 @@ class MetadataWriter: custom_gesture_classifier_metadata_writer: Metadata writer to write custom gesture classifier metadata into the TFLite file. """ - self._hand_detector_model_buffer = hand_detector_model_buffer - self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) self._gesture_embedder_model_buffer = gesture_embedder_model_buffer self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer @@ -147,16 +191,8 @@ class MetadataWriter: A tuple of (model_asset_bundle_in_bytes, metadata_json_content) """ # Creates the model asset bundle for hand landmarker task. - landmark_models = { - _HAND_DETECTOR_TFLITE_NAME: - self._hand_detector_model_buffer, - _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: - self._hand_landmarks_detector_model_buffer - } - output_hand_landmarker_path = os.path.join(self._temp_folder.name, - _HAND_LANDMARKER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(landmark_models, - output_hand_landmarker_path) + hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate( + ) # Write metadata into custom gesture classifier model. self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate( @@ -179,7 +215,7 @@ class MetadataWriter: # graph. gesture_recognizer_models = { _HAND_LANDMARKER_BUNDLE_NAME: - read_file(output_hand_landmarker_path), + hand_landmarker_model_buffer, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME: read_file(output_hand_gesture_recognizer_path), } diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index 83998141d..fd26b274d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path( class MetadataWriterTest(tf.test.TestCase): + def test_hand_landmarker_metadata_writer(self): + # Use dummy model buffer for unit test only. 
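+    # Arbitrary bytes work here because populate() only zips the buffers
+    # under fixed entry names; the bundling step does not parse them as
+    # TFLite models, which is why the zipfile check below can verify the
+    # archive layout alone.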
+    hand_detector_model_buffer = b"\x11\x12"
+    hand_landmarks_detector_model_buffer = b"\x22"
+    writer = metadata_writer.HandLandmarkerMetadataWriter(
+        hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
+    model_bundle_content = writer.populate()
+    model_bundle_filepath = os.path.join(self.get_temp_dir(),
+                                         "hand_landmarker.task")
+    with open(model_bundle_filepath, "wb") as f:
+      f.write(model_bundle_content)
+
+    with zipfile.ZipFile(model_bundle_filepath) as zf:
+      self.assertEqual(
+          set(zf.namelist()),
+          set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
+
   def test_write_metadata_and_create_model_asset_bundle_successful(self):
     # Use dummy model buffer for unit test only.
     hand_detector_model_buffer = b"\x11\x12"
diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD
index 447189d6f..f14d59b99 100644
--- a/mediapipe/tasks/python/core/BUILD
+++ b/mediapipe/tasks/python/core/BUILD
@@ -23,15 +23,15 @@ py_library(
     srcs = [
         "optional_dependencies.py",
     ],
-    deps = [
-        "@org_tensorflow//tensorflow/tools/docs:doc_controls",
-    ],
 )
 
 py_library(
     name = "base_options",
     srcs = ["base_options.py"],
-    visibility = ["//mediapipe/tasks:users"],
+    visibility = [
+        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/tasks:users",
+    ],
     deps = [
         ":optional_dependencies",
         "//mediapipe/tasks/cc/core/proto:base_options_py_pb2",
diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD
index 5f4aa38ff..eda8e290d 100644
--- a/mediapipe/tasks/python/vision/BUILD
+++ b/mediapipe/tasks/python/vision/BUILD
@@ -131,6 +131,10 @@ py_library(
     srcs = [
         "hand_landmarker.py",
     ],
+    visibility = [
+        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/tasks:internal",
+    ],
     deps = [
         "//mediapipe/framework/formats:classification_py_pb2",
         "//mediapipe/framework/formats:landmark_py_pb2",

From 0fbaa8dc8a0220d682081d70fd01bef71709a316 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Thu, 8 Dec 2022 12:59:46 -0800
Subject: [PATCH 128/346] Internal change.

PiperOrigin-RevId: 493973435
---
 .../framework/formats/tensor_ahwb_gpu_test.cc | 143 ++++++++++++++++++
 1 file changed, 143 insertions(+)
 create mode 100644 mediapipe/framework/formats/tensor_ahwb_gpu_test.cc

diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
new file mode 100644
index 000000000..dd865a367
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -0,0 +1,143 @@
+
+#if !defined(MEDIAPIPE_NO_JNI) && \
+    (__ANDROID_API__ >= 26 || \
+     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
+#include <android/hardware_buffer.h>
+
+#include <cstdint>
+
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/formats/tensor_data_types.h"
+#include "mediapipe/gpu/gpu_test_base.h"
+#include "mediapipe/gpu/shader_util.h"
+#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
+#include "testing/base/public/gunit.h"
+
+// The test creates an OpenGL ES buffer and fills it with incrementing values
+// (0.0, 0.1, 0.2, etc.) using a compute shader on the GPU.
+// Then the test requests the CPU view and compares the values.
+// Both Float32 and Float16 variants are covered.
+
+namespace {
+
+using mediapipe::Float16;
+using mediapipe::Tensor;
+
+MATCHER_P(NearWithPrecision, precision, "") {
+  return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision;
+}
+
+#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+
+// Utility function to fill the GPU buffer.
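+// Each invocation writes two consecutive elements, so the shader is
+// dispatched over size / 2 work groups. For kFloat16 the shader packs a
+// pair of half-precision values with packHalf2x16 and bit-casts the packed
+// word so it can be stored through the float-typed SSBO; the CPU side then
+// reads the buffer back as Float16.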
+void FillGpuBuffer(GLuint name, std::size_t size,
+                   const Tensor::ElementType fmt) {
+  std::string shader_source;
+  if (fmt == Tensor::ElementType::kFloat32) {
+    shader_source = R"( #version 310 es
+      precision highp float;
+      layout(local_size_x = 1, local_size_y = 1) in;
+      layout(std430, binding = 0) buffer Output {float elements[];} output_data;
+      void main() {
+        uint v = gl_GlobalInvocationID.x * 2u;
+        output_data.elements[v] = float(v) / 10.0;
+        output_data.elements[v + 1u] = float(v + 1u) / 10.0;
+      })";
+  } else {
+    shader_source = R"( #version 310 es
+      precision highp float;
+      layout(local_size_x = 1, local_size_y = 1) in;
+      layout(std430, binding = 0) buffer Output {float elements[];} output_data;
+      void main() {
+        uint v = gl_GlobalInvocationID.x;
+        uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0,
+                                     (float(v) * 2.0 + 1.0) / 10.0));
+        output_data.elements[v] = uintBitsToFloat(tmp);
+      })";
+  }
+  GLuint shader;
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER));
+  const GLchar* sources[] = {shader_source.c_str()};
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader));
+  GLint is_compiled = 0;
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS,
+                                  &is_compiled));
+  if (is_compiled == GL_FALSE) {
+    GLint max_length = 0;
+    MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH,
+                                    &max_length));
+    std::vector<GLchar> error_log(max_length);
+    glGetShaderInfoLog(shader, max_length, &max_length, error_log.data());
+    glDeleteShader(shader);
+    FAIL() << error_log.data();
+    return;
+  }
+  GLuint to_buffer_program;
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program));
+
+  MP_ASSERT_OK(
+      TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program));
+}
+
+class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
+ public:
+};
+
+TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  RunInGlContext([&tensor] {
+    auto ssbo_view = tensor.GetOpenGlBufferWriteView();
+    auto ssbo_name = ssbo_view.name();
+    EXPECT_GT(ssbo_name, 0);
+    FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
+                  tensor.element_type());
+  });
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  EXPECT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
+TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
+  RunInGlContext([&tensor] {
+    auto ssbo_view = tensor.GetOpenGlBufferWriteView();
+    auto ssbo_name = ssbo_view.name();
+    EXPECT_GT(ssbo_name, 0);
+    FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
+                  tensor.element_type());
+  });
+  auto ptr = tensor.GetCpuReadView().buffer<Float16>();
+  EXPECT_NE(ptr, nullptr);
+  std::vector<Float16> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  // Precision is set to a reasonable value for Float16.
+  EXPECT_THAT(absl::Span<const Float16>(ptr, num_elements),
+              testing::Pointwise(NearWithPrecision(0.001), reference));
+}
+
+#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+}  // namespace
+
+#endif  // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
+        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

From b4e1969e4381053038322275bfb8f15d855da9f2 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 8 Dec 2022 14:01:17 -0800
Subject: [PATCH 129/346] Add pip package builder for model_maker

PiperOrigin-RevId: 493989013
---
 mediapipe/model_maker/MANIFEST.in             |   1 +
 mediapipe/model_maker/__init__.py             |   6 +
 .../python/core/utils/file_util.py            |  19 ++-
 mediapipe/model_maker/python/text/__init__.py |  13 ++
 mediapipe/model_maker/requirements.txt        |   6 +-
 mediapipe/model_maker/setup.py                | 147 ++++++++++++++++++
 6 files changed, 184 insertions(+), 8 deletions(-)
 create mode 100644 mediapipe/model_maker/MANIFEST.in
 create mode 100644 mediapipe/model_maker/python/text/__init__.py
 create mode 100644 mediapipe/model_maker/setup.py

diff --git a/mediapipe/model_maker/MANIFEST.in b/mediapipe/model_maker/MANIFEST.in
new file mode 100644
index 000000000..54ce01aff
--- /dev/null
+++ b/mediapipe/model_maker/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include pip_src/mediapipe_model_maker/models *
diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py
index 7ca2f9216..9899a145b 100644
--- a/mediapipe/model_maker/__init__.py
+++ b/mediapipe/model_maker/__init__.py
@@ -11,3 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
+from mediapipe.model_maker.python.core.utils import quantization
+from mediapipe.model_maker.python.vision import image_classifier
+from mediapipe.model_maker.python.vision import gesture_recognizer
+from mediapipe.model_maker.python.text import text_classifier
diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py
index bccf928e2..66addad54 100644
--- a/mediapipe/model_maker/python/core/utils/file_util.py
+++ b/mediapipe/model_maker/python/core/utils/file_util.py
@@ -19,7 +19,7 @@ import os
 
 
 def get_absolute_path(file_path: str) -> str:
-  """Gets the absolute path of a file.
+  """Gets the absolute path of a file in the model_maker directory.
 
   Args:
     file_path: The path to a file relative to the `mediapipe` dir
@@ -27,10 +27,17 @@ def get_absolute_path(file_path: str) -> str:
   Returns:
     The full path of the file
   """
-  # Extract the file path before mediapipe/ as the `base_dir`. By joining it
-  # with the `path` which defines the relative path under mediapipe/, it
-  # yields to the absolute path of the model files directory.
+  # Extract the file path before and including 'model_maker' as the
+  # `mm_base_dir`. By joining it with the `path` after 'model_maker/', it
+  # yields the absolute path of the model files directory.
We must join + # on 'model_maker' because in the pypi package, the 'model_maker' directory + # is renamed to 'mediapipe_model_maker'. So we have to join on model_maker + # to ensure that the `mm_base_dir` path includes the renamed + # 'mediapipe_model_maker' directory. cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, file_path) + cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker') + mm_base_dir = cwd[:cwd_stop_idx] + file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1 + mm_relative_path = file_path[file_path_start_idx:] + absolute_path = os.path.join(mm_base_dir, mm_relative_path) return absolute_path diff --git a/mediapipe/model_maker/python/text/__init__.py b/mediapipe/model_maker/python/text/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 389ee484a..9b3c9f906 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,6 +1,8 @@ absl-py +mediapipe==0.9.1 numpy -opencv-contrib-python -tensorflow +opencv-python +tensorflow>=2.10 tensorflow-datasets tensorflow-hub +tf-models-official>=2.10.1 diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py new file mode 100644 index 000000000..ea193db94 --- /dev/null +++ b/mediapipe/model_maker/setup.py @@ -0,0 +1,147 @@ +"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Setup for Mediapipe-Model-Maker package with setuptools. 
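+
+A typical flow (assuming the standard setuptools commands, which this patch
+does not spell out) is:
+
+  python3 setup.py bdist_wheel
+
+which runs _setup_build_dir() at import time to mirror the sources into
+pip_src/ under the mediapipe_model_maker package name before packaging.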
+""" + +import glob +import os +import shutil +import subprocess +import sys +import setuptools + + +__version__ = 'dev' +MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) +# Build dir to copy all necessary files and build package +SRC_NAME = 'pip_src' +BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME) +BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker') + + +def _parse_requirements(path): + with open(os.path.join(MM_ROOT_PATH, path)) as f: + return [ + line.rstrip() + for line in f + if not (line.isspace() or line.startswith('#')) + ] + + +def _copy_to_pip_src_dir(file): + """Copy a file from bazel-bin to the pip_src dir.""" + dst = file + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file) + shutil.copyfile(src_file, file) + + +def _setup_build_dir(): + """Setup the BUILD_DIR directory to build the mediapipe_model_maker package. + + We need to create a new BUILD_DIR directory because any references to the path + `mediapipe/model_maker` needs to be renamed to `mediapipe_model_maker` to + avoid conflicting with the mediapipe package name. + This setup function performs the following actions: + 1. Copy python source code into BUILD_DIR and rename imports to + mediapipe_model_maker + 2. Download models from GCS into BUILD_DIR + """ + # Copy python source code into BUILD_DIR + if os.path.exists(BUILD_DIR): + shutil.rmtree(BUILD_DIR) + python_files = glob.glob('python/**/*.py', recursive=True) + python_files.append('__init__.py') + for python_file in python_files: + # Exclude test files from pip package + if '_test.py' in python_file: + continue + build_target_file = os.path.join(BUILD_MM_DIR, python_file) + with open(python_file, 'r') as file: + filedata = file.read() + # Rename all mediapipe.model_maker imports to mediapipe_model_maker + filedata = filedata.replace('from mediapipe.model_maker', + 'from mediapipe_model_maker') + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + with open(build_target_file, 'w') as file: + file.write(filedata) + + # Use bazel to download GCS model files + model_build_files = ['models/gesture_recognizer/BUILD'] + for model_build_file in model_build_files: + build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + shutil.copy(model_build_file, build_target_file) + external_files = [ + 'models/gesture_recognizer/canned_gesture_classifier.tflite', + 'models/gesture_recognizer/gesture_embedder.tflite', + 'models/gesture_recognizer/hand_landmark_full.tflite', + 'models/gesture_recognizer/palm_detection_full.tflite', + 'models/gesture_recognizer/gesture_embedder/keras_metadata.pb', + 'models/gesture_recognizer/gesture_embedder/saved_model.pb', + 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', + 'models/gesture_recognizer/gesture_embedder/variables/variables.index', + ] + for elem in external_files: + external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) + sys.stderr.write('downloading file: %s\n' % external_file) + fetch_model_command = [ + 'bazel', + 'build', + external_file, + ] + if subprocess.call(fetch_model_command) != 0: + sys.exit(-1) + _copy_to_pip_src_dir(external_file) + +_setup_build_dir() + +setuptools.setup( + name='mediapipe-model-maker', + version=__version__, + url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + description='MediaPipe Model Maker is a 
simple, low-code solution for customizing on-device ML models', + author='The MediaPipe Authors', + author_email='mediapipe@google.com', + long_description='', + long_description_content_type='text/markdown', + packages=setuptools.find_packages(where=SRC_NAME), + package_dir={'': SRC_NAME}, + install_requires=_parse_requirements('requirements.txt'), + include_package_data=True, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3 :: Only', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='Apache 2.0', + keywords=['mediapipe', 'model', 'maker'], +) From 05535db5f77bdc9c46df36855fe4064ded89d7cb Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 8 Dec 2022 15:01:34 -0800 Subject: [PATCH 130/346] Fix assertion failure in Hair Segmentation demo PiperOrigin-RevId: 494004801 --- mediapipe/graphs/hair_segmentation/BUILD | 1 + .../hair_segmentation_desktop_live.pbtxt | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mediapipe/graphs/hair_segmentation/BUILD b/mediapipe/graphs/hair_segmentation/BUILD index b177726bf..945f02c62 100644 --- a/mediapipe/graphs/hair_segmentation/BUILD +++ b/mediapipe/graphs/hair_segmentation/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:color_convert_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:set_alpha_calculator", diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt index 36c6970e1..f48b26be0 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt @@ -60,7 +60,14 @@ node { tag_index: "LOOP" back_edge: true } - output_stream: "PREV_LOOP:previous_hair_mask" + output_stream: "PREV_LOOP:previous_hair_mask_rgb" +} + +# Converts the 4 channel hair mask to a single channel mask +node { + calculator: "ColorConvertCalculator" + input_stream: "RGB_IN:previous_hair_mask_rgb" + output_stream: "GRAY_OUT:previous_hair_mask" } # Embeds the hair mask generated from the previous round of hair segmentation From bea0caae6586343eb91e986b84aa00ef75fa67b1 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 8 Dec 2022 17:05:06 -0800 Subject: [PATCH 131/346] Tensor: Cpu -> Ahwb storage transfer PiperOrigin-RevId: 494033280 --- mediapipe/framework/formats/tensor.cc | 4 +-- mediapipe/framework/formats/tensor.h | 2 +- mediapipe/framework/formats/tensor_ahwb.cc | 3 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 28 +++++++++++++++++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git 
a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 9e1406dbb..fdafbff5c 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -361,7 +361,7 @@ void Tensor::AllocateOpenGlBuffer() const { LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - if (!AllocateAhwbMapToSsbo()) { + if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) { glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); } glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); @@ -610,7 +610,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const { void Tensor::AllocateCpuBuffer() const { if (!cpu_buffer_) { #ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (AllocateAHardwareBuffer()) return; + if (use_ahwb_ && AllocateAHardwareBuffer()) return; #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_METAL_ENABLED cpu_buffer_ = AllocateVirtualMemory(bytes()); diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 151aa299d..9d3e90b6a 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -409,8 +409,8 @@ class Tensor { bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; // Use Ahwb for other views: OpenGL / CPU buffer. - static inline bool use_ahwb_ = false; #endif // MEDIAPIPE_TENSOR_USE_AHWB + static inline bool use_ahwb_ = false; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 21bae9593..3c3ec8b17 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -214,7 +214,7 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { "supported."; CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on targe system."; + "supported on target system."; bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; @@ -268,7 +268,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { - if (!use_ahwb_) return false; if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index dd865a367..7ccd9c7f5 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -136,6 +136,34 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { testing::Pointwise(NearWithPrecision(0.001), reference)); } +TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { + // Request the CPU view to get the memory to be allocated. + // Request Ahwb view then to transform the storage into Ahwb. 
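+  // Sequence: GetCpuWriteView() below allocates and fills a plain CPU
+  // buffer, GetAHardwareBufferReadView() then migrates that storage into
+  // an AHardwareBuffer, and GetCpuReadView() finally reads the same
+  // values back for verification.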
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  {
+    auto ptr = tensor.GetCpuWriteView().buffer<float>();
+    EXPECT_NE(ptr, nullptr);
+    for (int i = 0; i < num_elements; i++) {
+      ptr[i] = static_cast<float>(i) / 10.0f;
+    }
+  }
+  {
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+  }
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  EXPECT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
 #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 }  // namespace

From 3aeec84ac016d9899a7829ad5651753942dcf275 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 9 Dec 2022 03:19:45 -0800
Subject: [PATCH 132/346] Internal change for profiling

PiperOrigin-RevId: 494126771
---
 mediapipe/framework/validated_graph_config.cc | 10 ++++++++++
 mediapipe/framework/validated_graph_config.h  | 12 ++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc
index 01e3da83e..15eac3209 100644
--- a/mediapipe/framework/validated_graph_config.cc
+++ b/mediapipe/framework/validated_graph_config.cc
@@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize(
   input_side_packets_.clear();
   output_side_packets_.clear();
   stream_to_producer_.clear();
+  output_streams_to_consumer_nodes_.clear();
   input_streams_.clear();
   output_streams_.clear();
   owned_packet_types_.clear();
@@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
              << " does not have a corresponding output stream.";
     }
   }
+  // Add this node as a consumer of this edge's output stream.
+  if (edge_info.upstream > -1) {
+    auto parent_node = output_streams_[edge_info.upstream].parent_node;
+    if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) {
+      int this_idx = node_type_info->Node().index;
+      output_streams_to_consumer_nodes_[edge_info.upstream].push_back(
+          this_idx);
+    }
+  }
   edge_info.parent_node = node_type_info->Node();
   edge_info.name = name;
diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h
index 11f9553cd..95ecccbb4 100644
--- a/mediapipe/framework/validated_graph_config.h
+++ b/mediapipe/framework/validated_graph_config.h
@@ -282,6 +282,14 @@ class ValidatedGraphConfig {
     return output_streams_[iter->second].parent_node.index;
   }

+  std::vector<int> OutputStreamToConsumers(int idx) const {
+    auto iter = output_streams_to_consumer_nodes_.find(idx);
+    if (iter == output_streams_to_consumer_nodes_.end()) {
+      return {};
+    }
+    return iter->second;
+  }
+
   // Returns the registered type name of the specified side packet if
   // it can be determined, otherwise an appropriate error is returned.
   absl::StatusOr<std::string> RegisteredSidePacketTypeName(
@@ -418,6 +426,10 @@ class ValidatedGraphConfig {
   // Mapping from stream name to the output_streams_ index which produces it.
   std::map<std::string, int> stream_to_producer_;
+
+  // Mapping from output streams to consumer node ids. Used for profiling.
+  std::map<int, std::vector<int>> output_streams_to_consumer_nodes_;
+
   // Mapping from side packet name to the output_side_packets_ index
   // which produces it.
 std::map<std::string, int> side_packet_to_producer_;

From 4c4df2cf18955b2cc76f432b5f1321083d5bfb11 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 9 Dec 2022 04:11:05 -0800
Subject: [PATCH 133/346] Internal change for profiling

PiperOrigin-RevId: 494135244
---
 mediapipe/framework/profiler/graph_profiler.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h
index 29969af2e..23caed4ec 100644
--- a/mediapipe/framework/profiler/graph_profiler.h
+++ b/mediapipe/framework/profiler/graph_profiler.h
@@ -232,6 +232,11 @@ class GraphProfiler : public std::enable_shared_from_this<GraphProfiler> {

   const ProfilerConfig& profiler_config() { return profiler_config_; }

+  // Helper method to expose the config to other profilers.
+  const ValidatedGraphConfig* GetValidatedGraphConfig() {
+    return validated_graph_;
+  }
+
  private:
   // This can be used to add packet info for the input streams to the graph.
   // It treats the stream defined by |stream_name| as a stream produced by a

From 5bc1baf96acab858942d151d46b988ebe0577c00 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 9 Dec 2022 05:55:20 -0800
Subject: [PATCH 134/346] Internal change

PiperOrigin-RevId: 494150467
---
 mediapipe/framework/output_stream_shard.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h
index fdc5fe077..718174c45 100644
--- a/mediapipe/framework/output_stream_shard.h
+++ b/mediapipe/framework/output_stream_shard.h
@@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream {
   friend class GraphProfiler;
   // Accesses OutputStreamShard for profiling.
   friend class GraphTracer;
+  // Accesses OutputStreamShard for profiling.
+  friend class PerfettoTraceScope;
   // Accesses OutputStreamShard for post processing.
   friend class OutputStreamManager;
 };

From db3cb68d919693adb729437d1223d29c30736f27 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Fri, 9 Dec 2022 07:27:11 -0800
Subject: [PATCH 135/346] Internal change.
PiperOrigin-RevId: 494166776 --- .../formats/tensor_hardware_buffer.h | 71 ++++++ .../tensor_hardware_buffer_cpu_storage.cc | 216 ++++++++++++++++++ ...tensor_hardware_buffer_cpu_storage_test.cc | 76 ++++++ 3 files changed, 363 insertions(+) create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer.h create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h new file mode 100644 index 000000000..fa0241bde --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer.h @@ -0,0 +1,71 @@ +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#include + +#include + +#include "mediapipe/framework/formats/tensor_buffer.h" +#include "mediapipe/framework/formats/tensor_internal.h" +#include "mediapipe/framework/formats/tensor_v2.h" + +namespace mediapipe { + +// Supports: +// - float 16 and 32 bits +// - signed / unsigned integers 8,16,32 bits +class TensorHardwareBufferView; +struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor { + using ViewT = TensorHardwareBufferView; + TensorBufferDescriptor buffer; +}; + +class TensorHardwareBufferView : public Tensor::View { + public: + TENSOR_UNIQUE_VIEW_TYPE_ID(); + ~TensorHardwareBufferView() = default; + + const TensorHardwareBufferViewDescriptor& descriptor() const override { + return descriptor_; + } + AHardwareBuffer* handle() const { return ahwb_handle_; } + + protected: + TensorHardwareBufferView(int access_capability, Tensor::View::Access access, + Tensor::View::State state, + const TensorHardwareBufferViewDescriptor& desc, + AHardwareBuffer* ahwb_handle) + : Tensor::View(kId, access_capability, access, state), + descriptor_(desc), + ahwb_handle_(ahwb_handle) {} + + private: + bool MatchDescriptor( + uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor) const override { + if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor)) + return false; + auto descriptor = + static_cast(base_descriptor); + return descriptor.buffer.format == descriptor_.buffer.format && + descriptor.buffer.size_alignment <= + descriptor_.buffer.size_alignment && + descriptor_.buffer.size_alignment % + descriptor.buffer.size_alignment == + 0; + } + const TensorHardwareBufferViewDescriptor& descriptor_; + AHardwareBuffer* ahwb_handle_ = nullptr; +}; + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc new file mode 100644 index 000000000..9c223ce2c --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc @@ -0,0 +1,216 @@ +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "mediapipe/framework/formats/tensor_backend.h" +#include "mediapipe/framework/formats/tensor_cpu_buffer.h" +#include 
"mediapipe/framework/formats/tensor_hardware_buffer.h" +#include "mediapipe/framework/formats/tensor_v2.h" +#include "util/task/status_macros.h" + +namespace mediapipe { +namespace { + +class TensorCpuViewImpl : public TensorCpuView { + public: + TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access, + Tensor::View::State state, + const TensorCpuViewDescriptor& descriptor, void* pointer, + AHardwareBuffer* ahwb_handle) + : TensorCpuView(access_capabilities, access, state, descriptor, pointer), + ahwb_handle_(ahwb_handle) {} + ~TensorCpuViewImpl() { + // If handle_ is null then this view is constructed in GetViews with no + // access. + if (ahwb_handle_) { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_unlock(ahwb_handle_, nullptr); + } + } + } + + private: + AHardwareBuffer* ahwb_handle_; +}; + +class TensorHardwareBufferViewImpl : public TensorHardwareBufferView { + public: + TensorHardwareBufferViewImpl( + int access_capability, Tensor::View::Access access, + Tensor::View::State state, + const TensorHardwareBufferViewDescriptor& descriptor, + AHardwareBuffer* handle) + : TensorHardwareBufferView(access_capability, access, state, descriptor, + handle) {} + ~TensorHardwareBufferViewImpl() = default; +}; + +class HardwareBufferCpuStorage : public TensorStorage { + public: + ~HardwareBufferCpuStorage() { + if (!ahwb_handle_) return; + if (__builtin_available(android 26, *)) { + AHardwareBuffer_release(ahwb_handle_); + } + } + + static absl::Status CanProvide( + int access_capability, const Tensor::Shape& shape, uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor) { + // TODO: use AHardwareBuffer_isSupported for API >= 29. + static const bool is_ahwb_supported = [] { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc = {}; + // Aligned to the largest possible virtual memory page size. + constexpr uint32_t kPageSize = 16384; + desc.width = kPageSize; + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + AHardwareBuffer* handle; + if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false; + AHardwareBuffer_release(handle); + return true; + } + return false; + }(); + if (!is_ahwb_supported) { + return absl::UnavailableError( + "AHardwareBuffer is not supported on the platform."); + } + + if (view_type_id != TensorCpuView::kId && + view_type_id != TensorHardwareBufferView::kId) { + return absl::InvalidArgumentError( + "A view type is not supported by this storage."); + } + return absl::OkStatus(); + } + + std::vector> GetViews(uint64_t latest_version) { + std::vector> result; + auto update_state = latest_version == version_ + ? 
Tensor::View::State::kUpToDate + : Tensor::View::State::kOutdated; + if (ahwb_handle_) { + result.push_back( + std::unique_ptr(new TensorHardwareBufferViewImpl( + kAccessCapability, Tensor::View::Access::kNoAccess, update_state, + hw_descriptor_, ahwb_handle_))); + + result.push_back(std::unique_ptr(new TensorCpuViewImpl( + kAccessCapability, Tensor::View::Access::kNoAccess, update_state, + cpu_descriptor_, nullptr, nullptr))); + } + return result; + } + + absl::StatusOr> GetView( + Tensor::View::Access access, const Tensor::Shape& shape, + uint64_t latest_version, uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor, int access_capability) { + MP_RETURN_IF_ERROR( + CanProvide(access_capability, shape, view_type_id, base_descriptor)); + const auto& buffer_descriptor = + view_type_id == TensorHardwareBufferView::kId + ? static_cast( + base_descriptor) + .buffer + : static_cast(base_descriptor) + .buffer; + if (!ahwb_handle_) { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc = {}; + desc.width = TensorBufferSize(buffer_descriptor, shape); + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + // TODO: Use access capabilities to set hints. + desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_); + if (error != 0) { + return absl::UnknownError( + absl::StrCat("Error allocating hardware buffer: ", error)); + } + // Fill all possible views to provide it as proto views. + hw_descriptor_.buffer = buffer_descriptor; + cpu_descriptor_.buffer = buffer_descriptor; + } + } + if (buffer_descriptor.format != hw_descriptor_.buffer.format || + buffer_descriptor.size_alignment > + hw_descriptor_.buffer.size_alignment || + hw_descriptor_.buffer.size_alignment % + buffer_descriptor.size_alignment > + 0) { + return absl::AlreadyExistsError( + "A view with different params is already allocated with this " + "storage"); + } + + absl::StatusOr> result; + if (view_type_id == TensorHardwareBufferView::kId) { + result = GetAhwbView(access, shape, base_descriptor); + } else { + result = GetCpuView(access, shape, base_descriptor); + } + if (result.ok()) version_ = latest_version; + return result; + } + + private: + absl::StatusOr> GetAhwbView( + Tensor::View::Access access, const Tensor::Shape& shape, + const Tensor::ViewDescriptor& base_descriptor) { + return std::unique_ptr(new TensorHardwareBufferViewImpl( + kAccessCapability, access, Tensor::View::State::kUpToDate, + hw_descriptor_, ahwb_handle_)); + } + + absl::StatusOr> GetCpuView( + Tensor::View::Access access, const Tensor::Shape& shape, + const Tensor::ViewDescriptor& base_descriptor) { + void* pointer = nullptr; + if (__builtin_available(android 26, *)) { + int error = + AHardwareBuffer_lock(ahwb_handle_, + access == Tensor::View::Access::kWriteOnly + ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN + : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, + -1, nullptr, &pointer); + if (error != 0) { + return absl::UnknownError( + absl::StrCat("Error locking hardware buffer: ", error)); + } + } + return std::unique_ptr( + new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly + ? 
Tensor::View::AccessCapability::kWrite + : Tensor::View::AccessCapability::kRead, + access, Tensor::View::State::kUpToDate, + cpu_descriptor_, pointer, ahwb_handle_)); + } + + static constexpr int kAccessCapability = + Tensor::View::AccessCapability::kRead | + Tensor::View::AccessCapability::kWrite; + TensorHardwareBufferViewDescriptor hw_descriptor_; + AHardwareBuffer* ahwb_handle_ = nullptr; + + TensorCpuViewDescriptor cpu_descriptor_; + uint64_t version_ = 0; +}; +TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage); + +} // namespace +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc new file mode 100644 index 000000000..0afa9899f --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc @@ -0,0 +1,76 @@ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) +#include + +#include + +#include "mediapipe/framework/formats/tensor_cpu_buffer.h" +#include "mediapipe/framework/formats/tensor_hardware_buffer.h" +#include "mediapipe/framework/formats/tensor_v2.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +namespace mediapipe { + +namespace { + +class TensorHardwareBufferTest : public ::testing::Test { + public: + TensorHardwareBufferTest() {} + ~TensorHardwareBufferTest() override {} +}; + +TEST_F(TensorHardwareBufferTest, TestFloat32) { + Tensor tensor{Tensor::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->data(), nullptr); + } +} + +TEST_F(TensorHardwareBufferTest, TestInt8Padding) { + Tensor tensor{Tensor::Shape({1})}; + + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8, + .size_alignment = 4}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + EXPECT_NE(view->data(), nullptr); + } +} + +} // namespace + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From 453d67de92d19abf2488a4400532708d734b20bb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 9 Dec 2022 13:10:25 -0800 Subject: [PATCH 136/346] Add MergeDetectionsToVectorCalculator. 
PiperOrigin-RevId: 494246359 --- mediapipe/calculators/core/BUILD | 1 + mediapipe/calculators/core/merge_to_vector_calculator.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 29bca5fa6..2c143a609 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1323,6 +1323,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "@com_google_absl//absl/status", ], diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index 5f05ad725..fd053ed2b 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/calculators/core/merge_to_vector_calculator.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" namespace mediapipe { @@ -27,5 +28,9 @@ typedef MergeToVectorCalculator MergeGpuBuffersToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); +typedef MergeToVectorCalculator + MergeDetectionsToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator); + } // namespace api2 } // namespace mediapipe From 69c3c4c181766e4e94bca9f1db6ce49315d8ac45 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 9 Dec 2022 18:08:26 -0800 Subject: [PATCH 137/346] Internal change PiperOrigin-RevId: 494305195 --- mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts | 1 + mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts | 1 + mediapipe/tasks/web/core/task_runner.ts | 1 + mediapipe/tasks/web/text/text_classifier/text_classifier.ts | 1 + mediapipe/tasks/web/text/text_embedder/text_embedder.ts | 1 + .../tasks/web/vision/gesture_recognizer/gesture_recognizer.ts | 1 + mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts | 1 + mediapipe/tasks/web/vision/image_classifier/image_classifier.ts | 1 + mediapipe/tasks/web/vision/image_embedder/image_embedder.ts | 1 + mediapipe/tasks/web/vision/object_detector/object_detector.ts | 1 + 10 files changed, 10 insertions(+) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 265ba2b33..7bfca680a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -94,6 +94,7 @@ export class AudioClassifier extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 445dd5172..246cba883 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -96,6 +96,7 @@ export class AudioEmbedder extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 
6712c4d89..2011fadef 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -76,6 +76,7 @@ export abstract class TaskRunner { return createTaskRunner(type, initializeCanvas, fileset, options); } + /** @hideconstructor protected */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, graphRunner?: GraphRunnerImageLib) { diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 4a8588836..62708700a 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -92,6 +92,7 @@ export class TextClassifier extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index cd5bc644e..611233e02 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -96,6 +96,7 @@ export class TextEmbedder extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 69a8118a6..b6b795076 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -127,6 +127,7 @@ export class GestureRecognizer extends {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 9a0823f23..2a0e8286c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -115,6 +115,7 @@ export class HandLandmarker extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 40e8b5099..36e7311fb 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -93,6 +93,7 @@ export class ImageClassifier extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index f8b0204ee..0c45ba5e7 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -95,6 +95,7 @@ export class ImageEmbedder extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: 
HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e2cfe0575..fbfaced12 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -92,6 +92,7 @@ export class ObjectDetector extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { From edafef9fd8bfb34e91d8578f5ad68919b8cff702 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Fri, 9 Dec 2022 18:08:41 -0800 Subject: [PATCH 138/346] Updated issue templates. PiperOrigin-RevId: 494305235 --- .github/ISSUE_TEMPLATE/11-tasks-issue.md | 2 +- .github/ISSUE_TEMPLATE/12-model-maker-issue.md | 2 +- .../{10-solution-issue.md => 13-solution-issue.md} | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename .github/ISSUE_TEMPLATE/{10-solution-issue.md => 13-solution-issue.md} (81%) diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md index 264371120..4e9ae721d 100644 --- a/.github/ISSUE_TEMPLATE/11-tasks-issue.md +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -1,6 +1,6 @@ --- name: "Tasks Issue" -about: Use this template for assistance with using MediaPipe Tasks to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. +about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md index 258390d5e..31e8d7f1b 100644 --- a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -1,6 +1,6 @@ --- name: "Model Maker Issue" -about: Use this template for assistance with using MediaPipe Model Maker to create custom on-device ML solutions. +about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/10-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md similarity index 81% rename from .github/ISSUE_TEMPLATE/10-solution-issue.md rename to .github/ISSUE_TEMPLATE/13-solution-issue.md index a5332cb36..9297edf6b 100644 --- a/.github/ISSUE_TEMPLATE/10-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- -name: "Solution Issue" -about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +name: "Solution (legacy) Issue" +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. 
labels: type:support --- From e9bb51a524bc3c9e38aa7e689020172bea678069 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 9 Dec 2022 19:19:49 -0800 Subject: [PATCH 139/346] Internal change PiperOrigin-RevId: 494314595 --- .../mediapipe/apps/instantmotiontracking/GIFEditText.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 10e6422ba..1b733ed82 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import android.support.v7.widget.AppCompatEditText; +import androidx.appcompat.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; From 421f789edea501d5fbfd7078d2d9534a628dd886 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Sat, 10 Dec 2022 12:32:04 -0800 Subject: [PATCH 140/346] Internal change PiperOrigin-RevId: 494420725 --- mediapipe/framework/tool/BUILD | 2 + .../tool/calculator_graph_template.proto | 3 + mediapipe/framework/tool/proto_util_lite.cc | 103 +++++++++---- mediapipe/framework/tool/proto_util_lite.h | 28 +++- mediapipe/framework/tool/template_expander.cc | 136 ++++++++++++------ mediapipe/framework/tool/template_parser.cc | 128 ++++++++++++++++- mediapipe/framework/tool/testdata/BUILD | 10 ++ .../tool/testdata/frozen_generator.proto | 20 +++ 8 files changed, 348 insertions(+), 82 deletions(-) create mode 100644 mediapipe/framework/tool/testdata/frozen_generator.proto diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 453b5a0e8..89cb802da 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -346,6 +346,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", "@com_google_absl//absl/strings", ], ) @@ -506,6 +507,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":proto_util_lite", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:integral_types", diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto index 27153f3f7..31c233812 100644 --- a/mediapipe/framework/tool/calculator_graph_template.proto +++ b/mediapipe/framework/tool/calculator_graph_template.proto @@ -27,6 +27,9 @@ message TemplateExpression { // The FieldDescriptor::Type of the modified field. optional mediapipe.FieldDescriptorProto.Type field_type = 5; + // The FieldDescriptor::Type of each map key in the path. + repeated mediapipe.FieldDescriptorProto.Type key_type = 6; + // Alternative value for the modified field, in protobuf binary format. 
optional string field_value = 7; } diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index 4628815ea..a810ce129 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/field_data.pb.h" #include "mediapipe/framework/type_map.h" @@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. -absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, - CodedInputStream* in, CodedOutputStream* out, +absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, + CodedOutputStream* out, std::vector* field_values) { uint32 tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); + WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (field_number == field_id) { if (!IsLengthDelimited(wire_type) && IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) { @@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) { CodedInputStream in(&ais); StringOutputStream sos(&message_); CodedOutputStream out(&sos); - WireFormatLite::WireType wire_type = - WireFormatLite::WireTypeForFieldType(field_type_); - return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_); + return GetFieldValues(field_id_, &in, &out, &field_values_); } void FieldAccess::GetMessage(std::string* result) { @@ -149,18 +149,56 @@ std::vector* FieldAccess::mutable_field_values() { return &field_values_; } +namespace { +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; + +// Returns the FieldAccess and index for a field-id or a map-id. +// Returns access to the field-id if the field index is found, +// to the map-id if the map entry is found, and to the field-id otherwise. +absl::StatusOr> AccessField( + const ProtoPathEntry& entry, FieldType field_type, + const FieldValue& message) { + FieldAccess result(entry.field_id, field_type); + if (entry.field_id >= 0) { + MP_RETURN_IF_ERROR(result.SetMessage(message)); + if (entry.index < result.mutable_field_values()->size()) { + return std::pair(result, entry.index); + } + } + if (entry.map_id >= 0) { + FieldAccess access(entry.map_id, field_type); + MP_RETURN_IF_ERROR(access.SetMessage(message)); + auto& field_values = *access.mutable_field_values(); + for (int index = 0; index < field_values.size(); ++index) { + FieldAccess key(entry.key_id, entry.key_type); + MP_RETURN_IF_ERROR(key.SetMessage(field_values[index])); + if (key.mutable_field_values()->at(0) == entry.key_value) { + return std::pair(std::move(access), index); + } + } + } + if (entry.field_id >= 0) { + return std::pair(result, entry.index); + } + return absl::InvalidArgumentError(absl::StrCat( + "ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ", + entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type)); +} + +} // namespace + // Replaces a range of field values for one field nested within a protobuf. 
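The tag walk that GetFieldValues performs earlier in this hunk is plain protobuf wire format: read a varint tag, split it into field number and wire type, and either decode or skip the payload. As an illustration only (not MediaPipe's implementation; deprecated group wire types are omitted for brevity), a pure-Python sketch of collecting one field's raw values from a serialized message:

def collect_field_values(serialized: bytes, field_id: int):
  """Collects the raw wire-format values of one field (illustrative only)."""
  values, i = [], 0

  def varint():
    nonlocal i
    shift = result = 0
    while True:
      byte = serialized[i]
      i += 1
      result |= (byte & 0x7F) << shift
      if not byte & 0x80:
        return result
      shift += 7

  while i < len(serialized):
    tag = varint()
    field_number, wire_type = tag >> 3, tag & 7
    if wire_type == 0:    # varint
      value = varint()
    elif wire_type == 1:  # 64-bit
      value, i = serialized[i:i + 8], i + 8
    elif wire_type == 2:  # length-delimited (strings, messages, packed)
      size = varint()
      value, i = serialized[i:i + size], i + size
    elif wire_type == 5:  # 32-bit
      value, i = serialized[i:i + 4], i + 4
    else:
      raise ValueError(f'unsupported wire type {wire_type}')
    if field_number == field_id:
      values.append(value)
  return values

# Field 1 = varint 150, field 2 = string 'hi', field 1 again = varint 1.
print(collect_field_values(b'\x08\x96\x01\x12\x02hi\x08\x01', 1))  # [150, 1]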
absl::Status ProtoUtilLite::ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector& field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(*message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, @@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange( absl::Status ProtoUtilLite::GetFieldRange( const FieldValue& message, ProtoPath proto_path, int length, FieldType field_type, std::vector* field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR( GetFieldRange(v[index], proto_path, length, field_type, field_values)); } else { + if (length == -1) { + length = v.size() - index; + } RET_CHECK_NO_LOG(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size()); field_values->insert(field_values->begin(), v.begin() + index, @@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message, ProtoPath proto_path, FieldType field_type, int* field_count) { - int field_id, index; - std::tie(field_id, index) = proto_path.back(); - proto_path.pop_back(); - std::vector parent; - if (proto_path.empty()) { - parent.push_back(std::string(message)); + ProtoPathEntry entry = proto_path.front(); + proto_path.erase(proto_path.begin()); + FieldType type = + !proto_path.empty() ? 
WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); + if (!proto_path.empty()) { + RET_CHECK_NO_LOG(index >= 0 && index < v.size()); + MP_RETURN_IF_ERROR( + GetFieldCount(v[index], proto_path, field_type, field_count)); } else { - MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( - message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); + *field_count = v.size(); } - FieldAccess access(field_id, field_type); - MP_RETURN_IF_ERROR(access.SetMessage(parent[0])); - *field_count = access.mutable_field_values()->size(); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index 7d3a263f3..d71ceac83 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -34,15 +34,31 @@ class ProtoUtilLite { // Defines field types and tag formats. using WireFormatLite = proto_ns::internal::WireFormatLite; - // Defines a sequence of nested field-number field-index pairs. - using ProtoPath = std::vector>; - // The serialized value for a protobuf field. using FieldValue = std::string; // The serialized data type for a protobuf field. using FieldType = WireFormatLite::FieldType; + // A field-id and index, or a map-id and key, or both. + struct ProtoPathEntry { + ProtoPathEntry(int id, int index) : field_id(id), index(index) {} + ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value) + : map_id(id), + key_id(key_id), + key_type(key_type), + key_value(std::move(key_value)) {} + int field_id = -1; + int index = -1; + int map_id = -1; + int key_id = -1; + FieldType key_type; + FieldValue key_value; + }; + + // Defines a sequence of nested field-number field-index pairs. + using ProtoPath = std::vector; + class FieldAccess { public: // Provides access to a certain protobuf field. @@ -57,9 +73,11 @@ class ProtoUtilLite { // Returns the serialized values of the protobuf field. std::vector* mutable_field_values(); + uint32 field_id() const { return field_id_; } + private: - const uint32 field_id_; - const FieldType field_type_; + uint32 field_id_; + FieldType field_type_; std::string message_; std::vector field_values_; }; diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index 034e1a026..a91ea5adc 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -22,6 +22,7 @@ #include #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite; using FieldValue = ProtoUtilLite::FieldValue; using FieldType = ProtoUtilLite::FieldType; using ProtoPath = ProtoUtilLite::ProtoPath; +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; namespace { @@ -84,26 +86,87 @@ std::unique_ptr CloneMessage(const MessageLite& message) { return result; } -// Returns the (tag, index) pairs in a field path. -// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". 
-absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { - absl::Status status; - std::vector ids = absl::StrSplit(path, '/'); - for (const std::string& id : ids) { - if (id.length() > 0) { - std::pair id_pair = - absl::StrSplit(id, absl::ByAnyChar("[]")); - int tag = 0; - int index = 0; - bool ok = absl::SimpleAtoi(id_pair.first, &tag) && - absl::SimpleAtoi(id_pair.second, &index); - if (!ok) { - status.Update(absl::InvalidArgumentError(path)); - } - result->push_back(std::make_pair(tag, index)); +// Parses one ProtoPathEntry. +// The parsed entry is appended to `result` and removed from `path`. +// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes +// to serialize the key text to protobuf wire format. +absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) { + bool ok = true; + int sb = path.find('['); + int eb = path.find(']'); + int field_id = -1; + ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id); + auto selector = path.substr(sb + 1, eb - 1 - sb); + if (absl::StartsWith(selector, "@")) { + int eq = selector.find('='); + int key_id = -1; + ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id); + auto key_text = selector.substr(eq + 1); + FieldType key_type = FieldType::TYPE_STRING; + result->push_back({field_id, key_id, key_type, std::string(key_text)}); + } else { + int index = 0; + ok &= absl::SimpleAtoi(selector, &index); + result->push_back({field_id, index}); + } + int end = path.find('/', eb); + if (end == std::string::npos) { + path = ""; + } else { + path = path.substr(end + 1); + } + return ok ? absl::OkStatus() + : absl::InvalidArgumentError( + absl::StrCat("Failed to parse ProtoPath entry: ", path)); +} + +// Specifies the FieldTypes for protobuf map keys in a ProtoPath. +// Each ProtoPathEntry::key_value is converted from text to the protobuf +// wire format for its key type. +absl::Status SetMapKeyTypes(const std::vector& key_types, + ProtoPath* result) { + int i = 0; + for (ProtoPathEntry& entry : *result) { + if (entry.map_id >= 0) { + FieldType key_type = key_types[i++]; + std::vector key_value; + MP_RETURN_IF_ERROR( + ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value)); + entry.key_type = key_type; + entry.key_value = key_value.front(); } } - return status; + return absl::OkStatus(); +} + +// Returns the (tag, index) pairs in a field path. +// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]", +// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]". +absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { + result->clear(); + absl::string_view rest = path; + if (absl::StartsWith(rest, "/")) { + rest = rest.substr(1); + } + while (!rest.empty()) { + MP_RETURN_IF_ERROR(ParseEntry(rest, result)); + } + return absl::OkStatus(); +} + +// Parse the TemplateExpression.path field into a ProtoPath struct. +absl::Status ParseProtoPath(const TemplateExpression& rule, + std::string base_path, ProtoPath* result) { + ProtoPath base_entries; + MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries)); + MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result)); + std::vector key_types; + for (int type : rule.key_type()) { + key_types.push_back(static_cast(type)); + } + MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result)); + result->erase(result->begin(), result->begin() + base_entries.size()); + return absl::OkStatus(); } // Returns true if one proto path is prefix by another. 
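The path grammar accepted by the new ParseEntry and ProtoPathSplit has two segment forms: an index selector such as 7[2], and a map-key selector such as 3[@1=INPUT_FRAMES]. A regex-based Python sketch of the same grammar (an illustration, not the C++ parser; it does not validate the '/' separators or serialize keys the way SetMapKeyTypes does):

import re

# One segment per path element:
#   '<field>[<index>]'  or  '<map>[@<key_field>=<key_text>]'
_SEGMENT = re.compile(r'(\d+)\[(?:@(\d+)=([^\]]*)|(\d+))\]')

def proto_path_split(path: str):
  entries = []
  for field_id, key_id, key_text, index in _SEGMENT.findall(path):
    if key_id:
      entries.append({'map_id': int(field_id), 'key_id': int(key_id),
                      'key_value': key_text})
    else:
      entries.append({'field_id': int(field_id), 'index': int(index)})
  return entries

print(proto_path_split('/1[1]/2[@1=INPUT_FRAMES]'))
# [{'field_id': 1, 'index': 1},
#  {'map_id': 2, 'key_id': 1, 'key_value': 'INPUT_FRAMES'}]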
@@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) { return absl::StartsWith(path, prefix); } -// Returns the part of one proto path after a prefix proto path. -std::string ProtoPathRelative(const std::string& field_path, - const std::string& base_path) { - CHECK(ProtoPathStartsWith(field_path, base_path)); - return field_path.substr(base_path.length()); -} - // Returns the target ProtoUtilLite::FieldType of a rule. FieldType GetFieldType(const TemplateExpression& rule) { return static_cast(rule.field_type()); @@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) { // Returns the count of field values at a ProtoPath. int FieldCount(const FieldValue& base, ProtoPath field_path, FieldType field_type) { - int field_id, index; - std::tie(field_id, index) = field_path.back(); - field_path.pop_back(); - std::vector parent; - if (field_path.empty()) { - parent.push_back(base); - } else { - MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange( - base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); - } - ProtoUtilLite::FieldAccess access(field_id, field_type); - MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0])); - return access.mutable_field_values()->size(); + int result = 0; + CHECK( + ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok()); + return result; } } // namespace @@ -229,9 +276,7 @@ class TemplateExpanderImpl { return absl::OkStatus(); } ProtoPath field_path; - absl::Status status = - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path); - if (!status.ok()) return status; + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); return ProtoUtilLite::GetFieldRange(output, field_path, 1, GetFieldType(rule), base); } @@ -242,12 +287,13 @@ class TemplateExpanderImpl { const std::vector& field_values, FieldValue* output) { if (!rule.has_path()) { - *output = field_values[0]; + if (!field_values.empty()) { + *output = field_values[0]; + } return absl::OkStatus(); } ProtoPath field_path; - RET_CHECK_OK( - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path)); + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); int field_count = 1; if (rule.has_field_value()) { // For a non-repeated field, only one value can be specified. @@ -257,7 +303,7 @@ class TemplateExpanderImpl { "Multiple values specified for non-repeated field: ", rule.path())); } // For a non-repeated field, the field value is stored only in the rule. 
-      field_path[field_path.size() - 1].second = 0;
+      field_path[field_path.size() - 1].index = 0;
       field_count = 0;
     }
     return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,
diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc
index 1d81e7a78..5a0ceccd3 100644
--- a/mediapipe/framework/tool/template_parser.cc
+++ b/mediapipe/framework/tool/template_parser.cc
@@ -26,6 +26,7 @@
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
+#include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/deps/proto_descriptor.pb.h"
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/integral_types.h"
@@ -45,6 +46,9 @@
 using mediapipe::proto_ns::Message;
 using mediapipe::proto_ns::OneofDescriptor;
 using mediapipe::proto_ns::Reflection;
 using mediapipe::proto_ns::TextFormat;
+using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath;
+using FieldType = mediapipe::tool::ProtoUtilLite::FieldType;
+using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue;

 namespace mediapipe {

@@ -1357,7 +1361,7 @@ absl::Status ProtoPathSplit(const std::string& path,
       if (!ok) {
         status.Update(absl::InvalidArgumentError(path));
       }
-      result->push_back(std::make_pair(tag, index));
+      result->push_back({tag, index});
     }
   }
   return status;
@@ -1381,7 +1385,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) {
   const Descriptor* descriptor = message->GetDescriptor();
   ProtoUtilLite::ProtoPath path;
   MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
-  int field_number = path[path.size() - 1].first;
+  int field_number = path[path.size() - 1].field_id;
   const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
   if (!field->is_repeated()) {
     std::vector<ProtoUtilLite::FieldValue> field_values;
@@ -1402,6 +1406,124 @@ static void StripQuotes(std::string* str) {
   }
 }

+// Returns the field or extension for field number.
+const FieldDescriptor* FindFieldByNumber(const Message* message,
+                                         int field_num) {
+  const FieldDescriptor* result =
+      message->GetDescriptor()->FindFieldByNumber(field_num);
+  if (result == nullptr) {
+    result = message->GetReflection()->FindKnownExtensionByNumber(field_num);
+  }
+  return result;
+}
+
+// Returns the message value from a field at an index.
+const Message* GetFieldMessage(const Message& message,
+                               const FieldDescriptor* field, int index) {
+  if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
+    return nullptr;
+  }
+  if (!field->is_repeated()) {
+    return &message.GetReflection()->GetMessage(message, field);
+  }
+  if (index < message.GetReflection()->FieldSize(message, field)) {
+    return &message.GetReflection()->GetRepeatedMessage(message, field, index);
+  }
+  return nullptr;
+}
+
+// Serialize a ProtoPath as a readable string.
+// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
+// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
+std::string ProtoPathJoin(ProtoPath path) {
+  std::string result;
+  for (ProtoUtilLite::ProtoPathEntry& e : path) {
+    if (e.field_id >= 0) {
+      absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
+    } else if (e.map_id >= 0) {
+      absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
+                      "]");
+    }
+  }
+  return result;
+}
+
+// Returns the protobuf map key types from a ProtoPath.
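+// For example, {{1, 1}, {2, 1, TYPE_STRING, "INPUT_FRAMES"}} yields
+// {TYPE_STRING}; only map-key entries (map_id >= 0) contribute a key type.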
+std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
+  std::vector<FieldType> result;
+  for (auto& entry : path) {
+    if (entry.map_id >= 0) {
+      result.push_back(entry.key_type);
+    }
+  }
+  return result;
+}
+
+// Returns the text value for a string or numeric protobuf map key.
+std::string GetMapKey(const Message& map_entry) {
+  auto key_field = map_entry.GetDescriptor()->FindFieldByName("key");
+  auto reflection = map_entry.GetReflection();
+  if (key_field->type() == FieldDescriptor::TYPE_STRING) {
+    return reflection->GetString(map_entry, key_field);
+  } else if (key_field->type() == FieldDescriptor::TYPE_INT32) {
+    return absl::StrCat(reflection->GetInt32(map_entry, key_field));
+  } else if (key_field->type() == FieldDescriptor::TYPE_INT64) {
+    return absl::StrCat(reflection->GetInt64(map_entry, key_field));
+  }
+  return "";
+}
+
+// Adjusts map-entries from indexes to keys.
+// Protobuf map-entry order is intentionally not preserved.
+mediapipe::Status KeyProtoMapEntries(Message* source) {
+  // Copy the rules from the source CalculatorGraphTemplate.
+  mediapipe::CalculatorGraphTemplate rules;
+  rules.ParsePartialFromString(source->SerializePartialAsString());
+  // Only the "source" Message knows all extension types.
+  Message* config_0 = source->GetReflection()->MutableMessage(
+      source, source->GetDescriptor()->FindFieldByName("config"), nullptr);
+  for (int i = 0; i < rules.rule().size(); ++i) {
+    TemplateExpression* rule = rules.mutable_rule()->Mutable(i);
+    const Message* message = config_0;
+    ProtoPath path;
+    MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
+    for (int j = 0; j < path.size(); ++j) {
+      int field_id = path[j].field_id;
+      int field_index = path[j].index;
+      const FieldDescriptor* field = FindFieldByNumber(message, field_id);
+      if (field->is_map()) {
+        const Message* map_entry =
+            GetFieldMessage(*message, field, path[j].index);
+        int key_id =
+            map_entry->GetDescriptor()->FindFieldByName("key")->number();
+        FieldType key_type = static_cast<FieldType>(
+            map_entry->GetDescriptor()->FindFieldByName("key")->type());
+        std::string key_value = GetMapKey(*map_entry);
+        path[j] = {field_id, key_id, key_type, key_value};
+      }
+      message = GetFieldMessage(*message, field, field_index);
+      if (!message) {
+        break;
+      }
+    }
+    if (!rule->path().empty()) {
+      *rule->mutable_path() = ProtoPathJoin(path);
+      for (FieldType key_type : ProtoPathKeyTypes(path)) {
+        *rule->mutable_key_type()->Add() = key_type;
+      }
+    }
+  }
+  // Copy the rules back into the source CalculatorGraphTemplate.
+  auto source_rules =
+      source->GetReflection()->GetMutableRepeatedFieldRef<Message>(
+          source, source->GetDescriptor()->FindFieldByName("rule"));
+  source_rules.Clear();
+  for (auto& rule : rules.rule()) {
+    source_rules.Add(rule);
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace

 class TemplateParser::Parser::MediaPipeParserImpl
@@ -1416,6 +1538,8 @@ class TemplateParser::Parser::MediaPipeParserImpl
     // Copy the template rules into the output template "rule" field.
     success &= MergeFields(template_rules_, output).ok();

+    // Replace map-entry indexes with map keys.
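+    // KeyProtoMapEntries rewrites each rule path so map entries are
+    // addressed as "/<map_id>[@<key_id>=<key>]" selectors instead of by
+    // index, which survives protobuf map reordering.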
+    success &= KeyProtoMapEntries(output).ok();
     return success;
   }
diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD
index f9aab7b72..8300181b5 100644
--- a/mediapipe/framework/tool/testdata/BUILD
+++ b/mediapipe/framework/tool/testdata/BUILD
@@ -17,6 +17,7 @@ load(
     "//mediapipe/framework/tool:mediapipe_graph.bzl",
     "mediapipe_simple_subgraph",
 )
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

 licenses(["notice"])

@@ -58,3 +59,12 @@ mediapipe_simple_subgraph(
         "//mediapipe/framework:test_calculators",
     ],
 )
+
+mediapipe_proto_library(
+    name = "frozen_generator_proto",
+    srcs = ["frozen_generator.proto"],
+    visibility = ["//mediapipe/framework:__subpackages__"],
+    deps = [
+        "//mediapipe/framework:packet_generator_proto",
+    ],
+)
diff --git a/mediapipe/framework/tool/testdata/frozen_generator.proto b/mediapipe/framework/tool/testdata/frozen_generator.proto
new file mode 100644
index 000000000..5f133f461
--- /dev/null
+++ b/mediapipe/framework/tool/testdata/frozen_generator.proto
@@ -0,0 +1,20 @@
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/packet_generator.proto";
+
+message FrozenGeneratorOptions {
+  extend mediapipe.PacketGeneratorOptions {
+    optional FrozenGeneratorOptions ext = 225748738;
+  }
+
+  // Path to file containing serialized proto of type tensorflow::GraphDef.
+  optional string graph_proto_path = 1;
+
+  // This map defines which streams are fed to which tensors in the model.
+  map<string, string> tag_to_tensor_names = 2;
+
+  // Graph nodes to run to initialize the model.
+  repeated string initialization_op_names = 4;
+}
From 37d2e369605e87cf741220db1d3c1b4afb403def Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 12 Dec 2022 12:08:45 -0800
Subject: [PATCH 141/346] Internal change

PiperOrigin-RevId: 494791433
---
 .github/ISSUE_TEMPLATE/14-studio-issue.md | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)
 create mode 100644 .github/ISSUE_TEMPLATE/14-studio-issue.md

diff --git a/.github/ISSUE_TEMPLATE/14-studio-issue.md b/.github/ISSUE_TEMPLATE/14-studio-issue.md
new file mode 100644
index 000000000..5942b1eb1
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/14-studio-issue.md
@@ -0,0 +1,19 @@
+---
+name: "Studio Issue"
+about: Use this template for assistance with the MediaPipe Studio application.
+labels: type:support
+
+---
+Please make sure that this is a MediaPipe Studio issue.
+
+**System information** (Please provide as much relevant information as possible)
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4):
+- Browser and Version
+- Any microphone or camera hardware
+- URL that shows the problem
+
+**Describe the expected behavior:**
+
+**Other info / Complete Logs:**
+Include any js console logs that would be helpful to diagnose the problem.
+Large logs and files should be attached:
From 3f66dde8fdb459be8552b837e83fb2a79c44566c Mon Sep 17 00:00:00 2001
From: Mark McDonald
Date: Mon, 12 Dec 2022 17:33:08 -0800
Subject: [PATCH 142/346] Change `--site_path` default value to match the
 actual path.

This did not match the URL we ended up using for MediaPipe, so needs to be
set correctly in order to generate docs that match the real site.

This change sets the default to be correct.
PiperOrigin-RevId: 494874789 --- docs/build_py_api_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index 46546012d..02eb04074 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -44,14 +44,14 @@ _OUTPUT_DIR = flags.DEFINE_string( _URL_PREFIX = flags.DEFINE_string( 'code_url_prefix', - 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'https://github.com/google/mediapipe/blob/master/mediapipe', 'The url prefix for links to code.') _SEARCH_HINTS = flags.DEFINE_bool( 'search_hints', True, 'Include metadata search hints in the generated files') -_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python', 'Path prefix in the _toc.yaml') From fb2179761187f5a0c73c973d94690685170d9a21 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 12 Dec 2022 21:28:35 -0800 Subject: [PATCH 143/346] Internal change PiperOrigin-RevId: 494914168 --- mediapipe/calculators/image/image_cropping_calculator.cc | 3 ++- mediapipe/calculators/image/image_cropping_calculator_test.cc | 4 ++-- mediapipe/calculators/util/detections_to_rects_calculator.cc | 3 +++ .../calculators/util/detections_to_rects_calculator_test.cc | 3 +++ mediapipe/calculators/util/landmark_projection_calculator.cc | 2 ++ mediapipe/calculators/util/landmarks_smoothing_calculator.cc | 2 ++ mediapipe/calculators/util/rect_projection_calculator.cc | 2 ++ mediapipe/calculators/util/rect_to_render_data_calculator.cc | 3 +++ mediapipe/calculators/util/rect_to_render_scale_calculator.cc | 2 ++ mediapipe/calculators/util/rect_transformation_calculator.cc | 3 +++ .../calculators/util/world_landmark_projection_calculator.cc | 2 ++ .../calculators/video/tracked_detection_manager_calculator.cc | 2 ++ .../calculators/hand_landmarks_to_rect_calculator.cc | 2 ++ .../holistic_landmark/calculators/roi_tracking_calculator.cc | 2 ++ .../calculators/frame_annotation_to_rect_calculator.cc | 2 ++ .../cc/components/processors/image_preprocessing_graph.cc | 1 + .../calculators/landmarks_to_matrix_calculator.cc | 2 ++ .../calculators/landmarks_to_matrix_calculator_test.cc | 2 ++ .../tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc | 2 ++ .../cc/vision/gesture_recognizer/gesture_recognizer_graph.cc | 1 + .../gesture_recognizer/hand_gesture_recognizer_graph.cc | 1 + .../tasks/cc/vision/hand_detector/hand_detector_graph.cc | 1 + .../tasks/cc/vision/hand_detector/hand_detector_graph_test.cc | 1 + .../calculators/hand_association_calculator.cc | 2 ++ .../calculators/hand_association_calculator_test.cc | 2 ++ .../calculators/hand_landmarks_deduplication_calculator.cc | 1 + mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc | 2 ++ .../tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc | 1 + .../cc/vision/hand_landmarker/hand_landmarker_graph_test.cc | 1 + .../vision/hand_landmarker/hand_landmarks_detector_graph.cc | 1 + .../hand_landmarker/hand_landmarks_detector_graph_test.cc | 1 + .../tasks/cc/vision/image_classifier/image_classifier.cc | 1 + .../cc/vision/image_classifier/image_classifier_graph.cc | 1 + mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc | 1 + .../tasks/cc/vision/image_embedder/image_embedder_graph.cc | 1 + mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc | 1 + .../tasks/cc/vision/image_segmenter/image_segmenter_graph.cc | 1 + mediapipe/tasks/cc/vision/object_detector/object_detector.cc | 1 + 
 .../tasks/cc/vision/object_detector/object_detector_graph.cc  | 1 +
 mediapipe/util/rectangle_util_test.cc                         | 1 +
 mediapipe/util/tracking/tracked_detection.cc                  | 2 ++
 mediapipe/util/tracking/tracked_detection_manager.cc          | 1 +
 mediapipe/util/tracking/tracked_detection_test.cc             | 2 ++
 43 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc
index 8c9305ffb..1a2b2e5b0 100644
--- a/mediapipe/calculators/image/image_cropping_calculator.cc
+++ b/mediapipe/calculators/image/image_cropping_calculator.cc
@@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
 namespace mediapipe {

 namespace {
-
+using ::mediapipe::NormalizedRect;
+using ::mediapipe::Rect;
 #if !MEDIAPIPE_DISABLE_GPU

 #endif  // !MEDIAPIPE_DISABLE_GPU

diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc b/mediapipe/calculators/image/image_cropping_calculator_test.cc
index b3f692889..3c565282b 100644
--- a/mediapipe/calculators/image/image_cropping_calculator_test.cc
+++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc
@@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) {
   auto cc = absl::make_unique<CalculatorContext>(
       calculator_state.get(), inputTags, tool::CreateTagMap({}).value());
   auto& inputs = cc->Inputs();
-  mediapipe::Rect rect = ParseTextProtoOrDie<mediapipe::Rect>(
+  Rect rect = ParseTextProtoOrDie<Rect>(
       R"pb(
         width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5
       )pb");
-  inputs.Tag(kRectTag).Value() = MakePacket<mediapipe::Rect>(rect);
+  inputs.Tag(kRectTag).Value() = MakePacket<Rect>(rect);
   RectSpec expectRect = {
       .width = 1,
       .height = 1,
diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc
index 73a67d322..3e566836c 100644
--- a/mediapipe/calculators/util/detections_to_rects_calculator.cc
+++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc
@@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kRectsTag[] = "RECTS";
 constexpr char kNormRectsTag[] = "NORM_RECTS";

+using ::mediapipe::NormalizedRect;
+using ::mediapipe::Rect;
+
 constexpr float kMinFloat = std::numeric_limits<float>::lowest();
 constexpr float kMaxFloat = std::numeric_limits<float>::max();

diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc
index 6caf792a7..63de60a60 100644
--- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc
+++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc
@@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE";
 constexpr char kRectTag[] = "RECT";
 constexpr char kDetectionTag[] = "DETECTION";

+using ::mediapipe::NormalizedRect;
+using ::mediapipe::Rect;
+
 MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
   return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
          testing::Value(arg.y_center(), testing::Eq(y_center)) &&
diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc
index e27edea66..9f276da56 100644
--- a/mediapipe/calculators/util/landmark_projection_calculator.cc
+++ b/mediapipe/calculators/util/landmark_projection_calculator.cc
@@ -24,6 +24,8 @@

 namespace mediapipe {

+using ::mediapipe::NormalizedRect;
+
 namespace {

 constexpr char kLandmarksTag[] = "NORM_LANDMARKS";

diff --git
a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 6673816e7..7a92cfb7e 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; +using ::mediapipe::NormalizedRect; using mediapipe::OneEuroFilter; +using ::mediapipe::Rect; using mediapipe::RelativeVelocityFilter; void NormalizedLandmarksToLandmarks( diff --git a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index dcc6e7391..69b28af87 100644 --- a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -23,6 +23,8 @@ namespace { constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; +using ::mediapipe::NormalizedRect; + } // namespace // Projects rectangle from reference coordinate system (defined by reference diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 400be277d..bbc08255e 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr char kRectsTag[] = "RECTS"; constexpr char kRenderDataTag[] = "RENDER_DATA"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + RenderAnnotation::Rectangle* NewRect( const RectToRenderDataCalculatorOptions& options, RenderData* render_data) { auto* annotation = render_data->add_render_annotations(); diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index d94615228..6ff6b3d51 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRenderScaleTag[] = "RENDER_SCALE"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator to get scale for RenderData primitives. diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 15bb26826..4783cb919 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + // Wraps around an angle in radians to within -M_PI and M_PI. 
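// For example, NormalizeRadians(3 * M_PI / 2) == -M_PI / 2, since
// floor((3π/2 + π) / (2π)) == 1 and 3π/2 - 2π == -π/2.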
inline float NormalizeRadians(float angle) { return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI)); diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc index bcd7352a2..e843d63bf 100644 --- a/mediapipe/calculators/util/world_landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index c416fa9b0..48664fead 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -32,6 +32,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr int kDetectionUpdateTimeOutMS = 5000; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES"; diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index 6f2c49d64..638678ff5 100644 --- a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { // NORM_LANDMARKS is either the full set of landmarks for the hand, or diff --git a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc index 0da6cd7f7..49c7b93fb 100644 --- a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc +++ b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc @@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kTrackingRectTag[] = "TRACKING_RECT"; +using ::mediapipe::NormalizedRect; + // TODO: Use rect rotation. // Verifies that Intersection over Union of previous frame rect and current // frame re-crop rect is less than threshold. diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc index 476f8cb54..1fe919c54 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc @@ -34,6 +34,8 @@ namespace { constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION"; constexpr char kOutputNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator that converts FrameAnnotation proto to NormalizedRect. 
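The same one-line fix repeats across the files in this commit; as a minimal
self-contained sketch of the ambiguity the explicit using-declaration guards
against (hypothetical namespaces and members, not MediaPipe's real layout):

    namespace mediapipe {
    struct NormalizedRect { float x_center = 0; float y_center = 0; };
    namespace tasks {
    // Unqualified NormalizedRect resolves to the enclosing namespace today,
    // but a later tasks::NormalizedRect would silently capture every
    // unqualified use. The using-declaration pins the intended type.
    using ::mediapipe::NormalizedRect;
    inline NormalizedRect FullImageRoi() { return NormalizedRect{0.5f, 0.5f}; }
    }  // namespace tasks
    }  // namespace mediapipe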
diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index b24b7f0cb..fefc1ec52 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -45,6 +45,7 @@ namespace components { namespace processors { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::Tensor; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 277bb170a..088f97c29 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -35,6 +35,8 @@ limitations under the License. namespace mediapipe { namespace api2 { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index fe6f1162b..a1a44c8d1 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -33,6 +33,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index e7fcf6fd9..01f444742 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -57,6 +57,8 @@ namespace { using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: gesture_recognizer::proto::GestureRecognizerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandGestureSubgraphTypeName[] = "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 47d95100b..2d949c410 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -46,6 +46,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index d7e983d81..4db57e85b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -52,6 +52,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using 
::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index c24548c9b..49958e36b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -50,6 +50,7 @@ namespace hand_detector { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index cbbc0e193..f4e5f8c7d 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -53,6 +53,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index b6df80588..dffdbdd38 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -27,6 +27,8 @@ limitations under the License. namespace mediapipe::api2 { +using ::mediapipe::NormalizedRect; + // HandAssociationCalculator accepts multiple inputs of vectors of // NormalizedRect. The output is a vector of NormalizedRect that contains // rects from the input vectors that don't overlap with each other. When two diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc index cb3130854..138164209 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -26,6 +26,8 @@ limitations under the License. namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + class HandAssociationCalculatorTest : public testing::Test { protected: HandAssociationCalculatorTest() { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 266ce223f..d875de98f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -41,6 +41,7 @@ limitations under the License. 
namespace mediapipe::api2 { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 2b818b2e5..3bb1ee8d8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -46,6 +46,8 @@ namespace { using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: hand_landmarker::proto::HandLandmarkerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandLandmarkerGraphTypeName[] = "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 2c4133eb1..05ad97efe 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -49,6 +49,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index f275486f5..c28df2c05 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -54,6 +54,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 014830ba2..4ea066aab 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -53,6 +53,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index d1e928ce7..f28907d2f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -50,6 +50,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 60f8f7ed4..763e0a320 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -58,6 +58,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using 
::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2d0379c66..0adcf842d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -38,6 +38,7 @@ namespace image_classifier { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index e3198090f..494b075a7 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -54,6 +54,7 @@ constexpr char kGraphTypeName[] = "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 81ccb5361..95c4ff379 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -34,6 +34,7 @@ namespace image_embedder { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index bbee714c6..7130c72e2 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -44,6 +44,7 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 5531968c1..923cf2937 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -49,6 +49,7 @@ namespace image_segmenter { namespace { using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index e0222dd70..2477f8a44 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -57,6 +57,7 @@ constexpr 
char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index fd95bb1ac..e5af7544d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -52,6 +52,7 @@ namespace vision { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/util/rectangle_util_test.cc b/mediapipe/util/rectangle_util_test.cc index cd1946d45..3bc323f9f 100644 --- a/mediapipe/util/rectangle_util_test.cc +++ b/mediapipe/util/rectangle_util_test.cc @@ -20,6 +20,7 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; using ::testing::FloatNear; class RectangleUtilTest : public testing::Test { diff --git a/mediapipe/util/tracking/tracked_detection.cc b/mediapipe/util/tracking/tracked_detection.cc index 130a87640..80a6981a8 100644 --- a/mediapipe/util/tracking/tracked_detection.cc +++ b/mediapipe/util/tracking/tracked_detection.cc @@ -20,6 +20,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + // Struct for carrying boundary information. struct NormalizedRectBounds { float left, right, top, bottom; diff --git a/mediapipe/util/tracking/tracked_detection_manager.cc b/mediapipe/util/tracking/tracked_detection_manager.cc index 597827f3c..a9e348ceb 100644 --- a/mediapipe/util/tracking/tracked_detection_manager.cc +++ b/mediapipe/util/tracking/tracked_detection_manager.cc @@ -21,6 +21,7 @@ namespace { +using ::mediapipe::NormalizedRect; using mediapipe::TrackedDetection; // Checks if a point is out of view. diff --git a/mediapipe/util/tracking/tracked_detection_test.cc b/mediapipe/util/tracking/tracked_detection_test.cc index 60b9df1b1..13efaab92 100644 --- a/mediapipe/util/tracking/tracked_detection_test.cc +++ b/mediapipe/util/tracking/tracked_detection_test.cc @@ -18,6 +18,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + const float kErrorMargin = 1e-4f; TEST(TrackedDetectionTest, ConstructorWithoutBox) { From 78597c5b37a2ef8f3f005ed55f0a01676a08fb0b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 13 Dec 2022 09:05:19 -0800 Subject: [PATCH 144/346] Internal changes. PiperOrigin-RevId: 495038477 --- mediapipe/framework/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 3cc72b4f1..265ae9c6f 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -21,6 +21,7 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +# The MediaPipe internal package group. No mediapipe users should be added to this group. 
package_group(
    name = "mediapipe_internal",
    packages = [
From db404b1a8593a8b316cc4930dc1bcc845fc3df62 Mon Sep 17 00:00:00 2001
From: Hadon Nash
Date: Tue, 13 Dec 2022 10:21:07 -0800
Subject: [PATCH 145/346] Internal change

PiperOrigin-RevId: 495058817
---
 mediapipe/framework/tool/proto_util_lite.h  |   7 +-
 mediapipe/framework/tool/template_parser.cc | 181 +++++++++++++++-----
 2 files changed, 144 insertions(+), 44 deletions(-)

diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h
index d71ceac83..15e321eeb 100644
--- a/mediapipe/framework/tool/proto_util_lite.h
+++ b/mediapipe/framework/tool/proto_util_lite.h
@@ -48,11 +48,16 @@ class ProtoUtilLite {
           key_id(key_id),
           key_type(key_type),
           key_value(std::move(key_value)) {}
+    bool operator==(const ProtoPathEntry& o) const {
+      return field_id == o.field_id && index == o.index && map_id == o.map_id &&
+             key_id == o.key_id && key_type == o.key_type &&
+             key_value == o.key_value;
+    }
     int field_id = -1;
     int index = -1;
     int map_id = -1;
     int key_id = -1;
-    FieldType key_type;
+    FieldType key_type = FieldType::MAX_FIELD_TYPE;
     FieldValue key_value;
   };

diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc
index 5a0ceccd3..cf23f3443 100644
--- a/mediapipe/framework/tool/template_parser.cc
+++ b/mediapipe/framework/tool/template_parser.cc
@@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path,
   return status;
 }

+// Returns a message serialized deterministically.
+bool DeterministicallySerialize(const Message& proto, std::string* result) {
+  proto_ns::io::StringOutputStream stream(result);
+  proto_ns::io::CodedOutputStream output(&stream);
+  output.SetSerializationDeterministic(true);
+  return proto.SerializeToCodedStream(&output);
+}
+
 // Serialize one field of a message.
 void SerializeField(const Message* message, const FieldDescriptor* field,
                     std::vector<ProtoUtilLite::FieldValue>* result) {
   ProtoUtilLite::FieldValue message_bytes;
-  CHECK(message->SerializePartialToString(&message_bytes));
+  CHECK(DeterministicallySerialize(*message, &message_bytes));
   ProtoUtilLite::FieldAccess access(
       field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
   MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
   *result = *access.mutable_field_values();
 }

+// Serialize a ProtoPath as a readable string.
+// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
+// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
+std::string ProtoPathJoin(ProtoPath path) {
+  std::string result;
+  for (ProtoUtilLite::ProtoPathEntry& e : path) {
+    if (e.field_id >= 0) {
+      absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
+    } else if (e.map_id >= 0) {
+      absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
+                      "]");
+    }
+  }
+  return result;
+}
+
+// Returns the message value from a field at an index.
+const Message* GetFieldMessage(const Message& message,
+                               const FieldDescriptor* field, int index) {
+  if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
+    return nullptr;
+  }
+  if (!field->is_repeated()) {
+    return &message.GetReflection()->GetMessage(message, field);
+  }
+  if (index < message.GetReflection()->FieldSize(message, field)) {
+    return &message.GetReflection()->GetRepeatedMessage(message, field, index);
+  }
+  return nullptr;
+}
+
+// Returns all FieldDescriptors including extensions.
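+// Extensions are discovered through the file's DescriptorPool, so only
+// extensions registered with that pool are visible here.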
+std::vector<const FieldDescriptor*> GetFields(const Message* src) {
+  std::vector<const FieldDescriptor*> result;
+  src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(),
+                                                          &result);
+  for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) {
+    result.push_back(src->GetDescriptor()->field(i));
+  }
+  return result;
+}
+
+// Orders map entries in dst to match src.
+void OrderMapEntries(const Message* src, Message* dst,
+                     std::set<const Message*>* seen = nullptr) {
+  std::unique_ptr<std::set<const Message*>> seen_owner;
+  if (!seen) {
+    seen_owner = std::make_unique<std::set<const Message*>>();
+    seen = seen_owner.get();
+  }
+  if (seen->count(src) > 0) {
+    return;
+  } else {
+    seen->insert(src);
+  }
+  for (auto field : GetFields(src)) {
+    if (field->is_map()) {
+      dst->GetReflection()->ClearField(dst, field);
+      for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
+        const Message& entry =
+            src->GetReflection()->GetRepeatedMessage(*src, field, j);
+        dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry);
+      }
+    }
+    if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
+      if (field->is_repeated()) {
+        for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
+          OrderMapEntries(
+              &src->GetReflection()->GetRepeatedMessage(*src, field, j),
+              dst->GetReflection()->MutableRepeatedMessage(dst, field, j),
+              seen);
+        }
+      } else {
+        OrderMapEntries(&src->GetReflection()->GetMessage(*src, field),
+                        dst->GetReflection()->MutableMessage(dst, field), seen);
+      }
+    }
+  }
+}
+
+// Copies a Message, keeping map entries in order.
+std::unique_ptr<Message> CloneMessage(const Message* message) {
+  std::unique_ptr<Message> result(message->New());
+  result->CopyFrom(*message);
+  OrderMapEntries(message, result.get());
+  return result;
+}
+
+using MessageMap = std::map<std::string, std::unique_ptr<Message>>;
+
 // For a non-repeated field, move the most recently parsed field value
 // into the most recently parsed template expression.
-void StowFieldValue(Message* message, TemplateExpression* expression) {
+void StowFieldValue(Message* message, TemplateExpression* expression,
+                    MessageMap* stowed_messages) {
   const Reflection* reflection = message->GetReflection();
   const Descriptor* descriptor = message->GetDescriptor();
   ProtoUtilLite::ProtoPath path;
   MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
   int field_number = path[path.size() - 1].field_id;
   const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
+
+  // Save each stowed message unserialized preserving map entry order.
+  if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) {
+    (*stowed_messages)[ProtoPathJoin(path)] =
+        CloneMessage(GetFieldMessage(*message, field, 0));
+  }
+
   if (!field->is_repeated()) {
     std::vector<ProtoUtilLite::FieldValue> field_values;
@@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message,
   return result;
 }

-// Returns the message value from a field at an index.
-const Message* GetFieldMessage(const Message& message,
...
-// Serialize a ProtoPath as a readable string.
-// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
-// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
-std::string ProtoPathJoin(ProtoPath path) {
-  std::string result;
-  for (ProtoUtilLite::ProtoPathEntry& e : path) {
-    if (e.field_id >= 0) {
-      absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
-    } else if (e.map_id >= 0) {
-      absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
-                      "]");
-    }
-  }
-  return result;
-}
-
 // Returns the protobuf map key types from a ProtoPath.
 std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
   std::vector<FieldType> result;
@@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) {
   return "";
 }

+// Returns a Message stored in CalculatorGraphTemplate::field_value.
+Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) {
+  auto it = stowed_messages->find(ProtoPathJoin(proto_path));
+  return (it != stowed_messages->end()) ? it->second.get() : nullptr;
+}
+
+const Message* GetNestedMessage(const Message& message,
+                                const FieldDescriptor* field,
+                                ProtoPath proto_path,
+                                MessageMap* stowed_messages) {
+  if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
+    return nullptr;
+  }
+  const Message* result = FindStowedMessage(stowed_messages, proto_path);
+  if (!result) {
+    result = GetFieldMessage(message, field, proto_path.back().index);
+  }
+  return result;
+}
+
 // Adjusts map-entries from indexes to keys.
 // Protobuf map-entry order is intentionally not preserved.
-mediapipe::Status KeyProtoMapEntries(Message* source) {
+absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) {
   // Copy the rules from the source CalculatorGraphTemplate.
   mediapipe::CalculatorGraphTemplate rules;
   rules.ParsePartialFromString(source->SerializePartialAsString());
@@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) {
     MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
     for (int j = 0; j < path.size(); ++j) {
       int field_id = path[j].field_id;
-      int field_index = path[j].index;
       const FieldDescriptor* field = FindFieldByNumber(message, field_id);
+      ProtoPath prefix = {path.begin(), path.begin() + j + 1};
+      message = GetNestedMessage(*message, field, prefix, stowed_messages);
+      if (!message) {
+        break;
+      }
       if (field->is_map()) {
-        const Message* map_entry =
-            GetFieldMessage(*message, field, path[j].index);
+        const Message* map_entry = message;
         int key_id =
             map_entry->GetDescriptor()->FindFieldByName("key")->number();
         FieldType key_type = static_cast<FieldType>(
             map_entry->GetDescriptor()->FindFieldByName("key")->type());
         std::string key_value = GetMapKey(*map_entry);
         path[j] = {field_id, key_id, key_type, key_value};
       }
-      message = GetFieldMessage(*message, field, field_index);
-      if (!message) {
-        break;
-      }
     }
     if (!rule->path().empty()) {
       *rule->mutable_path() = ProtoPathJoin(path);
@@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
     // Copy the template rules into the output template "rule" field.
     success &= MergeFields(template_rules_, output).ok();
     // Replace map-entry indexes with map keys.
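+    // stowed_messages_ holds the unserialized sub-messages saved by
+    // StowFieldValue, preserving map-entry order for the keying pass below.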
-    success &= KeyProtoMapEntries(output).ok();
+    success &= KeyProtoMapEntries(output, &stowed_messages_).ok();
     return success;
   }
@@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
       DO(ConsumeFieldTemplate(message));
     } else {
       DO(ConsumeField(message));
-      StowFieldValue(message, expression);
+      StowFieldValue(message, expression, &stowed_messages_);
     }
     DO(ConsumeEndTemplate());
     return true;
@@ -1776,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
   }

   mediapipe::CalculatorGraphTemplate template_rules_;
+  std::map<std::string, std::unique_ptr<Message>> stowed_messages_;
 };

 #undef DO
From d5ff060bfa6930b9b6b1826b43ca0434b69050a9 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 13 Dec 2022 16:01:06 -0800
Subject: [PATCH 146/346] Internal change

PiperOrigin-RevId: 495149484
---
 mediapipe/graphs/object_detection_3d/calculators/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD
index 783fff187..d4c5c496b 100644
--- a/mediapipe/graphs/object_detection_3d/calculators/BUILD
+++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD
@@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"])
 mediapipe_proto_library(
     name = "gl_animation_overlay_calculator_proto",
     srcs = ["gl_animation_overlay_calculator.proto"],
+    def_options_lib = False,
     visibility = ["//visibility:public"],
     exports = [
         "//mediapipe/gpu:gl_animation_overlay_calculator_proto",
From b9d020cb7d32e936943b963c401cc3aeb9f88407 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 13 Dec 2022 16:08:54 -0800
Subject: [PATCH 147/346] Internal change

PiperOrigin-RevId: 495151410
---
 mediapipe/framework/BUILD | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index 265ae9c6f..872944acd 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -57,12 +57,12 @@ mediapipe_proto_library(
     srcs = ["calculator.proto"],
     visibility = ["//visibility:public"],
     deps = [
-        "//mediapipe/framework:calculator_options_proto",
-        "//mediapipe/framework:mediapipe_options_proto",
-        "//mediapipe/framework:packet_factory_proto",
-        "//mediapipe/framework:packet_generator_proto",
-        "//mediapipe/framework:status_handler_proto",
-        "//mediapipe/framework:stream_handler_proto",
+        ":calculator_options_proto",
+        ":mediapipe_options_proto",
+        ":packet_factory_proto",
+        ":packet_generator_proto",
+        ":status_handler_proto",
+        ":stream_handler_proto",
         "@com_google_protobuf//:any_proto",
     ],
 )
@@ -79,8 +79,8 @@ mediapipe_proto_library(
     srcs = ["calculator_contract_test.proto"],
     visibility = ["//mediapipe/framework:__subpackages__"],
     deps = [
-        "//mediapipe/framework:calculator_options_proto",
-        "//mediapipe/framework:calculator_proto",
+        ":calculator_options_proto",
+        ":calculator_proto",
     ],
 )
@@ -89,8 +89,8 @@ mediapipe_proto_library(
     srcs = ["calculator_profile.proto"],
     visibility = ["//visibility:public"],
     deps = [
-        "//mediapipe/framework:calculator_options_proto",
-        "//mediapipe/framework:calculator_proto",
+        ":calculator_options_proto",
+        ":calculator_proto",
     ],
 )
@@ -126,14 +126,14 @@ mediapipe_proto_library(
     name = "status_handler_proto",
     srcs = ["status_handler.proto"],
     visibility = [":mediapipe_internal"],
-    deps = ["//mediapipe/framework:mediapipe_options_proto"],
+    deps = [":mediapipe_options_proto"],
 )

 mediapipe_proto_library(
     name = "stream_handler_proto",
     srcs = ["stream_handler.proto"],
     visibility =
[":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( @@ -142,8 +142,8 @@ mediapipe_proto_library( srcs = ["test_calculators.proto"], visibility = [":mediapipe_internal"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -151,7 +151,7 @@ mediapipe_proto_library( name = "thread_pool_executor_proto", srcs = ["thread_pool_executor.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) # It is for pure-native Android builds where the library can't have any dependency on libandroid.so From b9d020cb7d32e936943b963c401cc3aeb9f88407 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Tue, 13 Dec 2022 16:58:12 -0800 Subject: [PATCH 148/346] Internal change PiperOrigin-RevId: 495163109 --- mediapipe/framework/scheduler.cc | 11 ++++++++--- mediapipe/framework/scheduler.h | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc index afef4f383..854c10fd5 100644 --- a/mediapipe/framework/scheduler.cc +++ b/mediapipe/framework/scheduler.cc @@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() { // Note: state_mutex_ is held when this function is entered or // exited. void Scheduler::HandleIdle() { - if (handling_idle_) { + if (++handling_idle_ > 1) { // Someone is already inside this method. // Note: This can happen in the sections below where we unlock the mutex // and make more nodes runnable: the nodes can run and become idle again @@ -127,7 +127,6 @@ void Scheduler::HandleIdle() { VLOG(2) << "HandleIdle: already in progress"; return; } - handling_idle_ = true; while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) { // Remove active sources that are closed. @@ -165,11 +164,17 @@ void Scheduler::HandleIdle() { } } + // If HandleIdle has been called again, then continue scheduling. + if (handling_idle_ > 1) { + handling_idle_ = 1; + continue; + } + // Nothing left to do. break; } - handling_idle_ = false; + handling_idle_ = 0; } // Note: state_mutex_ is held when this function is entered or exited. diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index dd1572d99..b59467b9f 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -302,7 +302,7 @@ class Scheduler { // - We need it to be reentrant, which Mutex does not support. // - We want simultaneous calls to return immediately instead of waiting, // and Mutex's TryLock is not guaranteed to work. - bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false; + int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0; // Mutex for the scheduler state and related things. 
// Note: state_ is declared as atomic so that its getter methods don't need
From 6fa0a58529ab60bd93bb622e4c97a0e796bb6276 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Wed, 14 Dec 2022 00:34:22 -0800
Subject: [PATCH 149/346] Internal change

PiperOrigin-RevId: 495235951
---
 .../framework/GraphTextureFrame.java          | 47 +++++++++++++++----
 .../framework/jni/graph_texture_frame_jni.cc  |  7 +++
 .../framework/jni/graph_texture_frame_jni.h   |  3 ++
 3 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java
index efaec34a7..586b5c0a0 100644
--- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java
+++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java
@@ -14,6 +14,10 @@

 package com.google.mediapipe.framework;

+import com.google.common.flogger.FluentLogger;
+import java.util.HashSet;
+import java.util.Set;
+
 /**
  * A {@link TextureFrame} that represents a texture produced by MediaPipe.
  *
@@ -21,6 +25,7 @@ package com.google.mediapipe.framework;
  * method.
  */
 public class GraphTextureFrame implements TextureFrame {
+  private static final FluentLogger logger = FluentLogger.forEnclosingClass();
   private long nativeBufferHandle;
   // We cache these to be able to get them without a JNI call.
   private int textureName;
@@ -30,6 +35,7 @@ public class GraphTextureFrame implements TextureFrame {
   // True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait
   // when calling getTextureName().
   private final boolean deferredSync;
+  private final Set<Long> activeConsumerContextHandleSet = new HashSet<>();

   GraphTextureFrame(long nativeHandle, long timestamp) {
     this(nativeHandle, timestamp, false);
@@ -54,17 +60,19 @@ public class GraphTextureFrame implements TextureFrame {
    * condition if release() is called after the if-check for nativeBufferHandle is already passed.
    */
   @Override
-  public int getTextureName() {
+  public synchronized int getTextureName() {
     // Return special texture id 0 if handle is 0 i.e. frame is already released.
     if (nativeBufferHandle == 0) {
       return 0;
     }
-    // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
-    // PacketGetter.getTextureFrameDeferredSync().
-    if (deferredSync) {
-      // Note that, if a CPU wait has already been done, the sync point will have been
-      // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
-      nativeGpuWait(nativeBufferHandle);
+    if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
+      // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
+      // PacketGetter.getTextureFrameDeferredSync().
+      if (deferredSync) {
+        // Note that, if a CPU wait has already been done, the sync point will have been
+        // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
+        nativeGpuWait(nativeBufferHandle);
+      }
     }
     return textureName;
   }
@@ -92,9 +100,14 @@ public class GraphTextureFrame implements TextureFrame {
    *

The consumer calls this when it is done using the texture. */ @Override - public void release() { - GlSyncToken consumerToken = - new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + public synchronized void release() { + GlSyncToken consumerToken = null; + // Note that this remove should be moved to the other overload of release when b/68808951 is + // addressed. + if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { + consumerToken = + new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + } release(consumerToken); } @@ -113,12 +126,24 @@ public class GraphTextureFrame implements TextureFrame { long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken(); nativeReleaseBuffer(nativeBufferHandle, token); nativeBufferHandle = 0; + } else if (consumerSyncToken != null) { + logger.atWarning().log("release with sync token, but handle is 0"); } if (consumerSyncToken != null) { consumerSyncToken.release(); } } + @Override + protected void finalize() throws Throwable { + if (nativeBufferHandle != 0) { + logger.atWarning().log("release was not called before finalize"); + } + if (!activeConsumerContextHandleSet.isEmpty()) { + logger.atWarning().log("active consumers did not release with sync before finalize"); + } + } + private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); private native int nativeGetTextureName(long nativeHandle); @@ -128,4 +153,6 @@ public class GraphTextureFrame implements TextureFrame { private native void nativeGpuWait(long nativeHandle); private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); + + private native long nativeGetCurrentExternalContextHandle(); } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 84df89260..963ea522e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -15,6 +15,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" @@ -84,3 +85,9 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( } return reinterpret_cast(token); } + +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) { + return reinterpret_cast( + mediapipe::GlContext::GetCurrentNativeContext()); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index 45637bb31..02903c664 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -44,6 +44,9 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From db6ea38cf69a72149e9b8e5e8868c6e3f33a4ac8 Mon Sep 17 00:00:00 
2001 From: Camillo Lugaresi Date: Wed, 14 Dec 2022 00:37:52 -0800 Subject: [PATCH 150/346] Internal change PiperOrigin-RevId: 495236576 --- .../framework/GraphTextureFrame.java | 42 +++++++++++++++---- .../mediapipe/framework/TextureFrame.java | 14 +++++++ .../framework/jni/graph_texture_frame_jni.cc | 16 ++++--- .../framework/jni/graph_texture_frame_jni.h | 5 ++- 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index 586b5c0a0..63ea7854b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -36,6 +36,7 @@ public class GraphTextureFrame implements TextureFrame { // when calling getTextureName(). private final boolean deferredSync; private final Set activeConsumerContextHandleSet = new HashSet<>(); + private int refCount = 1; GraphTextureFrame(long nativeHandle, long timestamp) { this(nativeHandle, timestamp, false); @@ -94,6 +95,17 @@ public class GraphTextureFrame implements TextureFrame { return timestamp; } + @Override + public boolean supportsRetain() { + return true; + } + + @Override + public synchronized void retain() { + // TODO: check that refCount is > 0 and handle is not 0. + refCount++; + } + /** * Releases a reference to the underlying buffer. * @@ -121,22 +133,32 @@ public class GraphTextureFrame implements TextureFrame { * currently cannot create a GlSyncToken, so they cannot call this method. */ @Override - public void release(GlSyncToken consumerSyncToken) { - if (nativeBufferHandle != 0) { - long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken(); - nativeReleaseBuffer(nativeBufferHandle, token); - nativeBufferHandle = 0; - } else if (consumerSyncToken != null) { - logger.atWarning().log("release with sync token, but handle is 0"); + public synchronized void release(GlSyncToken consumerSyncToken) { + if (nativeBufferHandle == 0) { + if (consumerSyncToken != null) { + logger.atWarning().log("release with sync token, but handle is 0"); + } + return; } + if (consumerSyncToken != null) { + long token = consumerSyncToken.nativeToken(); + nativeDidRead(nativeBufferHandle, token); + // We should remove the token's context from activeConsumerContextHandleSet here, but for now + // we do it in the release(void) overload. 
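+      // (For context: nativeDidRead hands the consumer's fence back to the
+      // producer-side GlTextureBuffer, which can then reuse the texture once
+      // that consumer's GPU work completes; releasing the Java token below
+      // only frees the token object itself.)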
consumerSyncToken.release(); } + + refCount--; + if (refCount <= 0) { + nativeReleaseBuffer(nativeBufferHandle); + nativeBufferHandle = 0; + } } @Override protected void finalize() throws Throwable { - if (nativeBufferHandle != 0) { + if (refCount >= 0 || nativeBufferHandle != 0) { logger.atWarning().log("release was not called before finalize"); } if (!activeConsumerContextHandleSet.isEmpty()) { @@ -144,7 +166,7 @@ public class GraphTextureFrame implements TextureFrame { } } - private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); + private native void nativeReleaseBuffer(long nativeHandle); private native int nativeGetTextureName(long nativeHandle); private native int nativeGetWidth(long nativeHandle); @@ -155,4 +177,6 @@ public class GraphTextureFrame implements TextureFrame { private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); private native long nativeGetCurrentExternalContextHandle(); + + private native void nativeDidRead(long nativeHandle, long consumerSyncToken); } diff --git a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java index babfd2958..76eaf39df 100644 --- a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java @@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback { */ @Override void release(GlSyncToken syncToken); + + /** + * If this method returns true, this object supports the retain method, and can be used with + * multiple consumers. Call retain for each additional consumer beyond the first; each consumer + * should call release. + */ + default boolean supportsRetain() { + return false; + } + + /** Increments the reference count. Only available with some implementations of TextureFrame. 
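+   *
+   * <p>A typical two-consumer hand-off might look like this (illustrative
+   * sketch only; {@code consumerA}/{@code consumerB} are hypothetical):
+   *
+   * <pre>{@code
+   * if (frame.supportsRetain()) {
+   *   frame.retain();        // take a second reference before sharing
+   * }
+   * consumerA.use(frame);    // ...later calls frame.release()
+   * consumerB.use(frame);    // ...later calls frame.release();
+   *                          // the last release frees the buffer
+   * }</pre>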
*/ + default void retain() { + throw new UnsupportedOperationException(); + } } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 963ea522e..dd99cccd4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -22,14 +22,9 @@ using mediapipe::GlTextureBufferSharedPtr; JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + JNIEnv* env, jobject thiz, jlong nativeHandle) { GlTextureBufferSharedPtr* buffer = reinterpret_cast(nativeHandle); - if (consumerSyncToken) { - mediapipe::GlSyncToken& token = - *reinterpret_cast(consumerSyncToken); - (*buffer)->DidRead(token); - } delete buffer; } @@ -91,3 +86,12 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( return reinterpret_cast( mediapipe::GlContext::GetCurrentNativeContext()); } + +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + GlTextureBufferSharedPtr* buffer = + reinterpret_cast(nativeHandle); + mediapipe::GlSyncToken& token = + *reinterpret_cast(consumerSyncToken); + (*buffer)->DidRead(token); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index 02903c664..41c531fff 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -26,7 +26,7 @@ extern "C" { // Releases a native mediapipe::GpuBuffer. JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + JNIEnv* env, jobject thiz, jlong nativeHandle); JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( JNIEnv* env, jobject thiz, jlong nativeHandle); @@ -44,6 +44,9 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz); From 7efb3bcf81081c822c76bb1d7e4867e5f1f66115 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:13:41 +0530 Subject: [PATCH 151/346] Added iOS task error codes --- mediapipe/tasks/ios/common/BUILD | 26 +++ .../tasks/ios/common/sources/MPPCommon.h | 179 ++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 mediapipe/tasks/ios/common/BUILD create mode 100644 mediapipe/tasks/ios/common/sources/MPPCommon.h diff --git a/mediapipe/tasks/ios/common/BUILD b/mediapipe/tasks/ios/common/BUILD new file mode 100644 index 000000000..0d00c423f --- /dev/null +++ b/mediapipe/tasks/ios/common/BUILD @@ -0,0 +1,26 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCommon", + hdrs = [ + "sources/MPPCommon.h", + ], + module_name = "MPPCommon", +) + diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h new file mode 100644 index 000000000..427b4cb75 --- /dev/null +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -0,0 +1,179 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * @enum MPPTasksErrorCode + * This enum specifies error codes for Mediapipe Task Library. + * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C ++libray. + */ +typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { + + // Generic error codes. + + // Unspecified error. + MPPTasksErrorCodeError = 1, + // Invalid argument specified. + MPPTasksErrorCodeInvalidArgumentError = 2, + // Invalid FlatBuffer file or buffer specified. + MPPTasksErrorCodeInvalidFlatBufferError = 3, + // Model contains a builtin op that isn't supported by the OpResolver or + // delegates. + MPPTasksErrorCodeUnsupportedBuiltinOp = 4, + // Model contains a custom op that isn't supported by the OpResolver or + // delegates. + MPPTasksErrorCodeUnsupportedCustomOp = 5, + + // File I/O error codes. + + // No such file. + MPPTasksErrorCodeFileNotFoundError = 100, + // Permission issue. + MPPTasksErrorCodeFilePermissionDeniedError, + // I/O error when reading file. + MPPTasksErrorCodeFileReadError, + // I/O error when mmap-ing file. + MPPTasksErrorCodeFileMmapError, + // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file. + MPPTasksErrorCodeFileZipError, + + // TensorFlow Lite metadata error codes. + + // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer. + MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, + // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed. + MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, + // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file. + MPPTasksErrorCodeMetadataAssociatedFileZipError, + // Inconsistency error between the metadata and actual TF Lite model. + // E.g.: number of labels and output tensor values differ. + MPPTasksErrorCodeMetadataInconsistencyError, + // Invalid process units specified. + // E.g.: multiple ProcessUnits with the same type for a given tensor. + MPPTasksErrorCodeMetadataInvalidProcessUnitsError, + // Inconsistency error with the number of labels. 
+ // E.g.: label files for different locales have a different number of labels. + MPPTasksErrorCodeMetadataNumLabelsMismatchError, + // Score calibration parameters parsing error. + // E.g.: too many parameters provided in the corresponding associated file. + MPPTasksErrorCodeMetadataMalformedScoreCalibrationError, + // Unexpected number of subgraphs for the current task. + // E.g.: image classification expects a single subgraph. + MPPTasksErrorCodeMetadataInvalidNumSubgraphsError, + // A given tensor requires NormalizationOptions but none were found. + // E.g.: float input tensor requires normalization to preprocess input images. + MPPTasksErrorCodeMetadataMissingNormalizationOptionsError, + // Invalid ContentProperties specified. + // E.g. expected ImageProperties, got BoundingBoxProperties. + MPPTasksErrorCodeMetadataInvalidContentPropertiesError, + // Metadata is mandatory but was not found. + // E.g. current task requires TFLite Model Metadata but none was found. + MPPTasksErrorCodeMetadataNotFoundError, + // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + // none was found or it was empty. + // E.g. current task requires labels but none were found. + MPPTasksErrorCodeMetadataMissingLabelsError, + // The ProcessingUnit for tokenizer is not correctly configured. + // E.g BertTokenizer doesn't have a valid vocab file associated. + MPPTasksErrorCodeMetadataInvalidTokenizerError, + + // Input tensor(s) error codes. + + // Unexpected number of input tensors for the current task. + // E.g. current task expects a single input tensor. + MPPTasksErrorCodeInvalidNumInputTensorsError = 300, + // Unexpected input tensor dimensions for the current task. + // E.g.: only 4D input tensors supported. + MPPTasksErrorCodeInvalidInputTensorDimensionsError, + // Unexpected input tensor type for the current task. + // E.g.: current task expects a uint8 pixel image as input. + MPPTasksErrorCodeInvalidInputTensorTypeError, + // Unexpected input tensor bytes size. + // E.g.: size in bytes does not correspond to the expected number of pixels. + MPPTasksErrorCodeInvalidInputTensorSizeError, + // No correct input tensor found for the model. + // E.g.: input tensor name is not part of the text model's input tensors. + MPPTasksErrorCodeInputTensorNotFoundError, + + // Output tensor(s) error codes. + + // Unexpected output tensor dimensions for the current task. + // E.g.: only a batch size of 1 is supported. + MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400, + // Unexpected input tensor type for the current task. + // E.g.: multi-head model with different output tensor types. + MPPTasksErrorCodeInvalidOutputTensorTypeError, + // No correct output tensor found for the model. + // E.g.: output tensor name is not part of the text model's output tensors. + MPPTasksErrorCodeOutputTensorNotFoundError, + // Unexpected number of output tensors for the current task. + // E.g.: current task expects a single output tensor. + MPPTasksErrorCodeInvalidNumOutputTensorsError, + + // Image processing error codes. + + // Unspecified image processing failures. + MPPTasksErrorCodeImageProcessingError = 500, + // Unexpected input or output buffer metadata. + // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + MPPTasksErrorCodeImageProcessingInvalidArgumentError, + // Image processing operation failures. + // E.g. libyuv rotation failed for an unknown reason. + MPPTasksErrorCodeImageProcessingBackendError, + + // Task runner error codes. 
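+  // (These mirror the task-runner statuses of the C++ layer, starting at 600.
+  // For example, calling a video-mode API on a task configured for image mode
+  // would surface MPPTasksErrorCodeRunnerApiCalledInWrongModeError.)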
+ MPPTasksErrorCodeRunnerError = 600, + // Task runner is not initialized. + MPPTasksErrorCodeRunnerInitializationError, + // Task runner is not started successfully. + MPPTasksErrorCodeRunnerFailsToStartError, + // Task runner is not started. + MPPTasksErrorCodeRunnerNotStartedError, + // Task runner API is called in the wrong processing mode. + MPPTasksErrorCodeRunnerApiCalledInWrongModeError, + // Task runner receives/produces invalid MediaPipe packet timestamp. + MPPTasksErrorCodeRunnerInvalidTimestampError, + // Task runner receives unexpected MediaPipe graph input packet. + // E.g. The packet type doesn't match the graph input stream's data type. + MPPTasksErrorCodeRunnerUnexpectedInputError, + // Task runner produces unexpected MediaPipe graph output packet. + // E.g. The number of output packets is not equal to the number of graph + // output streams. + MPPTasksErrorCodeRunnerUnexpectedOutputError, + // Task runner is not closed successfully. + MPPTasksErrorCodeRunnerFailsToCloseError, + // Task runner's model resources cache service is unavailable or the + // targeting model resources bundle is not found. + MPPTasksErrorCodeRunnerModelResourcesCacheServiceError, + + // Task graph error codes. + MPPTasksErrorCodeGraphError = 700, + // Task graph is not implemented. + MPPTasksErrorCodeTaskGraphNotImplementedError, + // Task graph config is invalid. + MPPTasksErrorCodeInvalidTaskGraphConfigError, + + MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, + + /** + * The last error code in TFLSupportErrorCode (for internal use only). + */ + MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, + +} NS_SWIFT_NAME(TasksErrorCode); + +NS_ASSUME_NONNULL_END From e9fb6c28f5d69e07e92ce9f88c234d7e2a0081f3 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:14:02 +0530 Subject: [PATCH 152/346] Added task options --- .../tasks/ios/core/sources/MPPTaskOptions.h | 48 +++++++++++++++++++ .../tasks/ios/core/sources/MPPTaskOptions.m | 36 ++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptions.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptions.m diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h new file mode 100644 index 000000000..0195f3654 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -0,0 +1,48 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * this class. + */ +NS_SWIFT_NAME(TaskOptions) +@interface MPPTaskOptions : NSObject +/** + * Base options for configuring the Mediapipe task. 
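+ * Defaults to a newly initialized MPPBaseOptions (see the initializer in
+ * MPPTaskOptions.m below); callers typically set its modelAssetPath before
+ * creating a task.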
+ */ +@property(nonatomic, copy) MPPBaseOptions *baseOptions; + +/** + * Initializes a new `MPPTaskOptions` with the absolute path to the model file + * stored locally on the device, set to the given the model path. + * + * @discussion The external model file must be a single standalone TFLite file. It could be packed + * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the + * necessary metadata and associated files might result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * + * @return An instance of `MPPTaskOptions` initialized to the given model path. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m new file mode 100644 index 000000000..e45364d55 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -0,0 +1,36 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPTaskOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _baseOptions = [[MPPBaseOptions alloc] init]; + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [self init]; + if (self) { + _baseOptions.modelAssetPath = modelPath; + } + return self; +} + +@end From 22bb87d9e0346cfcdc7e4e2d61baef0f7c987912 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:14:11 +0530 Subject: [PATCH 153/346] Added iOS task result --- .../tasks/ios/core/sources/MPPTaskResult.h | 34 +++++++++++++++++++ .../tasks/ios/core/sources/MPPTaskResult.m | 27 +++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.m diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h new file mode 100644 index 000000000..22171a852 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -0,0 +1,34 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * this class. + */ +NS_SWIFT_NAME(TaskResult) +@interface MPPTaskResult : NSObject +/** + * Base options for configuring the Mediapipe task. + */ +@property(nonatomic, assign, readonly) long timeStamp; + +- (instancetype)initWithTimeStamp:(long)timeStamp; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m new file mode 100644 index 000000000..ad74c009d --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -0,0 +1,27 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +@implementation MPPTaskResult + +- (instancetype)initWithTimeStamp:(long)timeStamp { + self = [self init]; + if (self) { + _timeStamp = timeStamp; + } + return self; +} + +@end From 0aedff06596a7ee43588489e8dd8ad8d2d24a7b2 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:14:49 +0530 Subject: [PATCH 154/346] Added target for task options --- mediapipe/tasks/ios/core/BUILD | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 3f1193e46..cee0fa4eb 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -21,3 +21,12 @@ objc_library( srcs = ["sources/MPPBaseOptions.m"], hdrs = ["sources/MPPBaseOptions.h"], ) + +objc_library( + name = "MPPTaskOptions", + srcs = ["sources/MPPTaskOptions.m"], + hdrs = ["sources/MPPTaskOptions.h"], + deps = [ + ":MPPBaseOptions", + ], +) From c0fed7df3116db8778052b29de6ab906a95083fa Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 19:15:01 +0530 Subject: [PATCH 155/346] Added target for task result --- mediapipe/tasks/ios/core/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index cee0fa4eb..7b648945e 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -30,3 +30,9 @@ objc_library( ":MPPBaseOptions", ], ) + +objc_library( + name = "MPPTaskResult", + srcs = ["sources/MPPTaskResult.m"], + hdrs = ["sources/MPPTaskResult.h"], +) From 174f2869a335a075764f1364130b7d9529b93a29 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 14 Dec 2022 08:31:49 -0800 Subject: [PATCH 156/346] Internal changes PiperOrigin-RevId: 495322170 --- mediapipe/framework/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 872944acd..0dd694760 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1,4 +1,3 @@ -# # Copyright 2019 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); From e9e173f9fa37948bcb9a028f7822c44773a2bbcf Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 14 Dec 2022 18:09:20 -0800 Subject: [PATCH 157/346] Internal change PiperOrigin-RevId: 495468694 --- mediapipe/framework/api2/port.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index e63d3651e..eee542640 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -557,8 +557,8 @@ class OutputSidePacketAccess { if (output_) output_->Set(ToOldPacket(std::move(packet))); } - void Set(const T& payload) { Set(MakePacket(payload)); } - void Set(T&& payload) { Set(MakePacket(std::move(payload))); } + void Set(const T& payload) { Set(api2::MakePacket(payload)); } + void Set(T&& payload) { Set(api2::MakePacket(std::move(payload))); } private: OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {} From d526b20e19339712e12db73a2f07d07a2c919b01 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 14 Dec 2022 19:52:13 -0800 Subject: [PATCH 158/346] Internal change. PiperOrigin-RevId: 495483878 --- .../formats/tensor_hardware_buffer.h | 71 ------ .../tensor_hardware_buffer_cpu_storage.cc | 216 ------------------ ...tensor_hardware_buffer_cpu_storage_test.cc | 76 ------ 3 files changed, 363 deletions(-) delete mode 100644 mediapipe/framework/formats/tensor_hardware_buffer.h delete mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc delete mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h deleted file mode 100644 index fa0241bde..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ -#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ - -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#include - -#include - -#include "mediapipe/framework/formats/tensor_buffer.h" -#include "mediapipe/framework/formats/tensor_internal.h" -#include "mediapipe/framework/formats/tensor_v2.h" - -namespace mediapipe { - -// Supports: -// - float 16 and 32 bits -// - signed / unsigned integers 8,16,32 bits -class TensorHardwareBufferView; -struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor { - using ViewT = TensorHardwareBufferView; - TensorBufferDescriptor buffer; -}; - -class TensorHardwareBufferView : public Tensor::View { - public: - TENSOR_UNIQUE_VIEW_TYPE_ID(); - ~TensorHardwareBufferView() = default; - - const TensorHardwareBufferViewDescriptor& descriptor() const override { - return descriptor_; - } - AHardwareBuffer* handle() const { return ahwb_handle_; } - - protected: - TensorHardwareBufferView(int access_capability, Tensor::View::Access access, - Tensor::View::State state, - const TensorHardwareBufferViewDescriptor& desc, - AHardwareBuffer* ahwb_handle) - : Tensor::View(kId, access_capability, access, state), - descriptor_(desc), - ahwb_handle_(ahwb_handle) {} - - private: - bool MatchDescriptor( - uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor) const override { - if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor)) - return false; - auto descriptor = - static_cast(base_descriptor); - 
return descriptor.buffer.format == descriptor_.buffer.format && - descriptor.buffer.size_alignment <= - descriptor_.buffer.size_alignment && - descriptor_.buffer.size_alignment % - descriptor.buffer.size_alignment == - 0; - } - const TensorHardwareBufferViewDescriptor& descriptor_; - AHardwareBuffer* ahwb_handle_ = nullptr; -}; - -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc deleted file mode 100644 index 9c223ce2c..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc +++ /dev/null @@ -1,216 +0,0 @@ -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#include - -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "mediapipe/framework/formats/tensor_backend.h" -#include "mediapipe/framework/formats/tensor_cpu_buffer.h" -#include "mediapipe/framework/formats/tensor_hardware_buffer.h" -#include "mediapipe/framework/formats/tensor_v2.h" -#include "util/task/status_macros.h" - -namespace mediapipe { -namespace { - -class TensorCpuViewImpl : public TensorCpuView { - public: - TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access, - Tensor::View::State state, - const TensorCpuViewDescriptor& descriptor, void* pointer, - AHardwareBuffer* ahwb_handle) - : TensorCpuView(access_capabilities, access, state, descriptor, pointer), - ahwb_handle_(ahwb_handle) {} - ~TensorCpuViewImpl() { - // If handle_ is null then this view is constructed in GetViews with no - // access. - if (ahwb_handle_) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_unlock(ahwb_handle_, nullptr); - } - } - } - - private: - AHardwareBuffer* ahwb_handle_; -}; - -class TensorHardwareBufferViewImpl : public TensorHardwareBufferView { - public: - TensorHardwareBufferViewImpl( - int access_capability, Tensor::View::Access access, - Tensor::View::State state, - const TensorHardwareBufferViewDescriptor& descriptor, - AHardwareBuffer* handle) - : TensorHardwareBufferView(access_capability, access, state, descriptor, - handle) {} - ~TensorHardwareBufferViewImpl() = default; -}; - -class HardwareBufferCpuStorage : public TensorStorage { - public: - ~HardwareBufferCpuStorage() { - if (!ahwb_handle_) return; - if (__builtin_available(android 26, *)) { - AHardwareBuffer_release(ahwb_handle_); - } - } - - static absl::Status CanProvide( - int access_capability, const Tensor::Shape& shape, uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor) { - // TODO: use AHardwareBuffer_isSupported for API >= 29. - static const bool is_ahwb_supported = [] { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_Desc desc = {}; - // Aligned to the largest possible virtual memory page size. 
- constexpr uint32_t kPageSize = 16384; - desc.width = kPageSize; - desc.height = 1; - desc.layers = 1; - desc.format = AHARDWAREBUFFER_FORMAT_BLOB; - desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; - AHardwareBuffer* handle; - if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false; - AHardwareBuffer_release(handle); - return true; - } - return false; - }(); - if (!is_ahwb_supported) { - return absl::UnavailableError( - "AHardwareBuffer is not supported on the platform."); - } - - if (view_type_id != TensorCpuView::kId && - view_type_id != TensorHardwareBufferView::kId) { - return absl::InvalidArgumentError( - "A view type is not supported by this storage."); - } - return absl::OkStatus(); - } - - std::vector> GetViews(uint64_t latest_version) { - std::vector> result; - auto update_state = latest_version == version_ - ? Tensor::View::State::kUpToDate - : Tensor::View::State::kOutdated; - if (ahwb_handle_) { - result.push_back( - std::unique_ptr(new TensorHardwareBufferViewImpl( - kAccessCapability, Tensor::View::Access::kNoAccess, update_state, - hw_descriptor_, ahwb_handle_))); - - result.push_back(std::unique_ptr(new TensorCpuViewImpl( - kAccessCapability, Tensor::View::Access::kNoAccess, update_state, - cpu_descriptor_, nullptr, nullptr))); - } - return result; - } - - absl::StatusOr> GetView( - Tensor::View::Access access, const Tensor::Shape& shape, - uint64_t latest_version, uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor, int access_capability) { - MP_RETURN_IF_ERROR( - CanProvide(access_capability, shape, view_type_id, base_descriptor)); - const auto& buffer_descriptor = - view_type_id == TensorHardwareBufferView::kId - ? static_cast( - base_descriptor) - .buffer - : static_cast(base_descriptor) - .buffer; - if (!ahwb_handle_) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_Desc desc = {}; - desc.width = TensorBufferSize(buffer_descriptor, shape); - desc.height = 1; - desc.layers = 1; - desc.format = AHARDWAREBUFFER_FORMAT_BLOB; - // TODO: Use access capabilities to set hints. - desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; - auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_); - if (error != 0) { - return absl::UnknownError( - absl::StrCat("Error allocating hardware buffer: ", error)); - } - // Fill all possible views to provide it as proto views. 
- hw_descriptor_.buffer = buffer_descriptor; - cpu_descriptor_.buffer = buffer_descriptor; - } - } - if (buffer_descriptor.format != hw_descriptor_.buffer.format || - buffer_descriptor.size_alignment > - hw_descriptor_.buffer.size_alignment || - hw_descriptor_.buffer.size_alignment % - buffer_descriptor.size_alignment > - 0) { - return absl::AlreadyExistsError( - "A view with different params is already allocated with this " - "storage"); - } - - absl::StatusOr> result; - if (view_type_id == TensorHardwareBufferView::kId) { - result = GetAhwbView(access, shape, base_descriptor); - } else { - result = GetCpuView(access, shape, base_descriptor); - } - if (result.ok()) version_ = latest_version; - return result; - } - - private: - absl::StatusOr> GetAhwbView( - Tensor::View::Access access, const Tensor::Shape& shape, - const Tensor::ViewDescriptor& base_descriptor) { - return std::unique_ptr(new TensorHardwareBufferViewImpl( - kAccessCapability, access, Tensor::View::State::kUpToDate, - hw_descriptor_, ahwb_handle_)); - } - - absl::StatusOr> GetCpuView( - Tensor::View::Access access, const Tensor::Shape& shape, - const Tensor::ViewDescriptor& base_descriptor) { - void* pointer = nullptr; - if (__builtin_available(android 26, *)) { - int error = - AHardwareBuffer_lock(ahwb_handle_, - access == Tensor::View::Access::kWriteOnly - ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN - : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, - -1, nullptr, &pointer); - if (error != 0) { - return absl::UnknownError( - absl::StrCat("Error locking hardware buffer: ", error)); - } - } - return std::unique_ptr( - new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly - ? Tensor::View::AccessCapability::kWrite - : Tensor::View::AccessCapability::kRead, - access, Tensor::View::State::kUpToDate, - cpu_descriptor_, pointer, ahwb_handle_)); - } - - static constexpr int kAccessCapability = - Tensor::View::AccessCapability::kRead | - Tensor::View::AccessCapability::kWrite; - TensorHardwareBufferViewDescriptor hw_descriptor_; - AHardwareBuffer* ahwb_handle_ = nullptr; - - TensorCpuViewDescriptor cpu_descriptor_; - uint64_t version_ = 0; -}; -TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage); - -} // namespace -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || - // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc deleted file mode 100644 index 0afa9899f..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc +++ /dev/null @@ -1,76 +0,0 @@ - -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) -#include - -#include - -#include "mediapipe/framework/formats/tensor_cpu_buffer.h" -#include "mediapipe/framework/formats/tensor_hardware_buffer.h" -#include "mediapipe/framework/formats/tensor_v2.h" -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" - -namespace mediapipe { - -namespace { - -class TensorHardwareBufferTest : public ::testing::Test { - public: - TensorHardwareBufferTest() {} - ~TensorHardwareBufferTest() override {} -}; - -TEST_F(TensorHardwareBufferTest, TestFloat32) { - Tensor tensor{Tensor::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorHardwareBufferViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - EXPECT_NE(view->handle(), 
nullptr); - } - { - const auto& const_tensor = tensor; - MP_ASSERT_OK_AND_ASSIGN( - auto view, - const_tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - EXPECT_NE(view->data(), nullptr); - } -} - -TEST_F(TensorHardwareBufferTest, TestInt8Padding) { - Tensor tensor{Tensor::Shape({1})}; - - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorHardwareBufferViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8, - .size_alignment = 4}})); - EXPECT_NE(view->handle(), nullptr); - } - { - const auto& const_tensor = tensor; - MP_ASSERT_OK_AND_ASSIGN( - auto view, - const_tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - EXPECT_NE(view->data(), nullptr); - } -} - -} // namespace - -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || - // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From bf91c5240782364792739cef7deabbc60c6db77e Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:21:07 +0530 Subject: [PATCH 159/346] Fixed typos --- mediapipe/tasks/ios/common/sources/MPPCommon.h | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 427b4cb75..b3d715520 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -48,16 +48,16 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { MPPTasksErrorCodeFileReadError, // I/O error when mmap-ing file. MPPTasksErrorCodeFileMmapError, - // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file. + // ZIP I/O error when unpacking the zip file. MPPTasksErrorCodeFileZipError, // TensorFlow Lite metadata error codes. - // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer. + // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, - // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed. + // No such associated file within metadata, or file has not been packed. MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, - // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file. + // ZIP I/O error when unpacking an associated file. MPPTasksErrorCodeMetadataAssociatedFileZipError, // Inconsistency error between the metadata and actual TF Lite model. // E.g.: number of labels and output tensor values differ. @@ -167,11 +167,10 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { // Task graph config is invalid. MPPTasksErrorCodeInvalidTaskGraphConfigError, + // The first error code in MPPTasksErrorCode (for internal use only). MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, - /** - * The last error code in TFLSupportErrorCode (for internal use only). - */ + // The last error code in MPPTasksErrorCode (for internal use only). 
MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, } NS_SWIFT_NAME(TasksErrorCode); From fe7fbc0b38b23a0639d816ffdd3fc64da0734c9b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:21:14 +0530 Subject: [PATCH 160/346] Fixed comment --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index 0195f3654..6a00de6f5 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -22,6 +22,7 @@ NS_ASSUME_NONNULL_BEGIN * this class. */ NS_SWIFT_NAME(TaskOptions) + @interface MPPTaskOptions : NSObject /** * Base options for configuring the Mediapipe task. @@ -32,10 +33,9 @@ NS_SWIFT_NAME(TaskOptions) * Initializes a new `MPPTaskOptions` with the absolute path to the model file * stored locally on the device, set to the given the model path. * - * @discussion The external model file must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the - * necessary metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * @discussion The external model file must be a single standalone TFLite file. It must be packed + * with TFLite Model Metadata[1] and associated files. Failure to provide the + * necessary metadata and associated files will result in errors. * * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. * From 9ab010758421f6e8cea9d840ff597181c67070a8 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:21:22 +0530 Subject: [PATCH 161/346] Added new line --- mediapipe/tasks/ios/core/sources/MPPTaskResult.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index 22171a852..89555fe32 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -21,6 +21,7 @@ NS_ASSUME_NONNULL_BEGIN * this class. */ NS_SWIFT_NAME(TaskResult) + @interface MPPTaskResult : NSObject /** * Base options for configuring the Mediapipe task. From 163b13d7de654bd996e26a3c2f6659ca9d481833 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:23:27 +0530 Subject: [PATCH 162/346] Updated comments --- mediapipe/tasks/ios/core/sources/MPPTaskResult.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index 89555fe32..f1707a767 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -17,14 +17,14 @@ NS_ASSUME_NONNULL_BEGIN /** - * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * MediaPipe Tasks result base class. Any MediaPipe task result class should extend * this class. */ NS_SWIFT_NAME(TaskResult) @interface MPPTaskResult : NSObject /** - * Base options for configuring the Mediapipe task. + * Timestamp that is associated with the task result object. 
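+ * It is assigned once via initWithTimeStamp: and is read-only thereafter.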
*/ @property(nonatomic, assign, readonly) long timeStamp; From 5ab17fe686ab2fd20936f3351f7df6c619ff9684 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 15 Dec 2022 10:28:50 +0530 Subject: [PATCH 163/346] Removed convenience initializer --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.h | 14 -------------- mediapipe/tasks/ios/core/sources/MPPTaskOptions.m | 8 -------- 2 files changed, 22 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index 6a00de6f5..ee2f7d032 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -29,20 +29,6 @@ NS_SWIFT_NAME(TaskOptions) */ @property(nonatomic, copy) MPPBaseOptions *baseOptions; -/** - * Initializes a new `MPPTaskOptions` with the absolute path to the model file - * stored locally on the device, set to the given the model path. - * - * @discussion The external model file must be a single standalone TFLite file. It must be packed - * with TFLite Model Metadata[1] and associated files. Failure to provide the - * necessary metadata and associated files will result in errors. - * - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. - * - * @return An instance of `MPPTaskOptions` initialized to the given model path. - */ -- (instancetype)initWithModelPath:(NSString *)modelPath; - @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index e45364d55..e3cf6684a 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -25,12 +25,4 @@ return self; } -- (instancetype)initWithModelPath:(NSString *)modelPath { - self = [self init]; - if (self) { - _baseOptions.modelAssetPath = modelPath; - } - return self; -} - @end From 6db5eabe0b4ec6090f4dc45c241ab24aa0f2d59e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 15 Dec 2022 00:41:27 -0800 Subject: [PATCH 164/346] Internal change PiperOrigin-RevId: 495525736 --- docs/solutions/holistic.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 8c552834e..11589425d 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic # For static images: IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray with mp_holistic.Holistic( static_image_mode=True, model_complexity=2, From 675420341fca61cceaa9d6b8054b858c0695bd6e Mon Sep 17 00:00:00 2001 From: Ayush Gupta Date: Thu, 15 Dec 2022 16:06:54 +0530 Subject: [PATCH 165/346] Internal Change --- .github/bot_config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev From 299aa03302d66d1ed449eaf10e01702b633538ac Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 15 Dec 2022 09:20:22 -0800 Subject: [PATCH 166/346] Internal change PiperOrigin-RevId: 495613573 --- .../audioclassifier/AudioClassifier.java | 2 + .../audio/audioembedder/AudioEmbedder.java | 2 + .../com/google/mediapipe/tasks/core/BUILD | 12 +++ .../google/mediapipe/tasks/core/TaskInfo.java | 12 ++- .../mediapipe/tasks/core/TaskRunner.java | 29 +++++- 
.../core/logging/TasksStatsDummyLogger.java | 78 +++++++++++++++ .../tasks/core/logging/TasksStatsLogger.java | 98 +++++++++++++++++++ .../text/textclassifier/TextClassifier.java | 1 + .../tasks/text/textembedder/TextEmbedder.java | 1 + .../gesturerecognizer/GestureRecognizer.java | 2 + .../vision/handlandmarker/HandLandmarker.java | 2 + .../imageclassifier/ImageClassifier.java | 2 + .../vision/imageembedder/ImageEmbedder.java | 2 + .../vision/objectdetector/ObjectDetector.java | 2 + 14 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index d78685fe3..4e5cd7655 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(AudioClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index 4bc505d84..077f28ca2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(AudioEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 31f885267..3eb28d38b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -22,6 +22,7 @@ android_library( ], manifest = "AndroidManifest.xml", deps = [ + ":logging", "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", "//mediapipe/framework:calculator_java_proto_lite", @@ -37,6 +38,17 @@ android_library( ], ) +android_library( + name = "logging", + srcs = glob( + ["logging/*.java"], + ), + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar") mediapipe_tasks_core_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 12f8be8ba..310f5739c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -32,6 +32,12 @@ public abstract 
class TaskInfo { /** Builder for {@link TaskInfo}. */ @AutoValue.Builder public abstract static class Builder { + /** Sets the MediaPipe task name. */ + public abstract Builder setTaskName(String value); + + /** Sets the MediaPipe task running mode name. */ + public abstract Builder setTaskRunningModeName(String value); + /** Sets the MediaPipe task graph name. */ public abstract Builder setTaskGraphName(String value); @@ -71,6 +77,10 @@ public abstract class TaskInfo { } } + abstract String taskName(); + + abstract String taskRunningModeName(); + abstract String taskGraphName(); abstract T taskOptions(); @@ -82,7 +92,7 @@ public abstract class TaskInfo { abstract Boolean enableFlowLimiting(); public static Builder builder() { - return new AutoValue_TaskInfo.Builder(); + return new AutoValue_TaskInfo.Builder().setTaskName("").setTaskRunningModeName(""); } /* Returns a list of the output stream names without the stream tags. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index e6fc91cf6..1a128c538 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.tasks.core.logging.TasksStatsLogger; +import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable { private final Graph graph; private final ModelResourcesCache modelResourcesCache; private final AndroidPacketCreator packetCreator; + private final TasksStatsLogger statsLogger; private long lastSeenTimestamp = Long.MIN_VALUE; private ErrorListener errorListener; @@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable { Context context, TaskInfo taskInfo, OutputHandler outputHandler) { + TasksStatsLogger statsLogger = + TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName()); AndroidAssetUtil.initializeNativeAssetManager(context); Graph mediapipeGraph = new Graph(); mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig()); @@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable { mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache); mediapipeGraph.addMultiStreamCallback( taskInfo.outputStreamNames(), - outputHandler::run, - /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges()); + packets -> { + outputHandler.run(packets); + statsLogger.recordInvocationEnd(packets.get(0).getTimestamp()); + }, + /* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges()); mediapipeGraph.startRunningGraph(); // Waits until all calculators are opened and the graph is fully started. mediapipeGraph.waitUntilGraphIdle(); - return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler); + return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger); } /** @@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable { * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs. 
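   * <p>(With the change above, each call also records input arrival and
   * invocation end with the stats logger, keyed by the packet timestamp.)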
*/ public synchronized TaskResult process(Map inputs) { - addPackets(inputs, generateSyntheticTimestamp()); + long syntheticInputTimestamp = generateSyntheticTimestamp(); + // TODO: Support recording GPU input arrival. + statsLogger.recordCpuInputArrival(syntheticInputTimestamp); + addPackets(inputs, syntheticInputTimestamp); graph.waitUntilGraphIdle(); lastSeenTimestamp = outputHandler.getLatestOutputTimestamp(); return outputHandler.retrieveCachedTaskResult(); @@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized TaskResult process(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); graph.waitUntilGraphIdle(); return outputHandler.retrieveCachedTaskResult(); @@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized void send(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); } @@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); } catch (MediaPipeException e) { reportError(e); } @@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable { // Waits until all calculators are opened and the graph is fully restarted. graph.waitUntilGraphIdle(); graphStarted.set(true); + statsLogger.logSessionStart(); } catch (MediaPipeException e) { reportError(e); } @@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); if (modelResourcesCache != null) { modelResourcesCache.release(); } @@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable { private TaskRunner( Graph graph, ModelResourcesCache modelResourcesCache, - OutputHandler outputHandler) { + OutputHandler outputHandler, + TasksStatsLogger statsLogger) { this.outputHandler = outputHandler; this.graph = graph; this.modelResourcesCache = modelResourcesCache; this.packetCreator = new AndroidPacketCreator(graph); + this.statsLogger = statsLogger; graphStarted.set(true); + this.statsLogger.logSessionStart(); } /** Reports error. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java new file mode 100644 index 000000000..c10b5d224 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java @@ -0,0 +1,78 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import android.content.Context; + +/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. 
*/ +public class TasksStatsDummyLogger implements TasksStatsLogger { + + /** + * Creates the MediaPipe Tasks stats dummy logger. + * + * @param context a {@link Context}. + * @param taskNameStr the task api name. + * @param taskRunningModeStr the task running mode string representation. + */ + public static TasksStatsDummyLogger create( + Context context, String taskNameStr, String taskRunningModeStr) { + return new TasksStatsDummyLogger(); + } + + private TasksStatsDummyLogger() {} + + /** Logs the start of a MediaPipe Tasks API session. */ + @Override + public void logSessionStart() {} + + /** + * Records MediaPipe Tasks API receiving CPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordCpuInputArrival(long packetTimestamp) {} + + /** + * Records MediaPipe Tasks API receiving GPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordGpuInputArrival(long packetTimestamp) {} + + /** + * Records the end of a Mediapipe Tasks API invocation. + * + * @param packetTimestamp the output packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordInvocationEnd(long packetTimestamp) {} + + /** Logs the MediaPipe Tasks API periodic invocation report. */ + @Override + public void logInvocationReport(StatsSnapshot stats) {} + + /** Logs the Tasks API session end event. */ + @Override + public void logSessionEnd() {} + + /** Logs the MediaPipe Tasks API initialization error. */ + @Override + public void logInitError() {} +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java new file mode 100644 index 000000000..c726e7d0d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java @@ -0,0 +1,98 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import com.google.auto.value.AutoValue; + +/** The stats logger interface that defines what MediaPipe Tasks events to log. */ +public interface TasksStatsLogger { + /** Task stats snapshot. 
*/ + @AutoValue + abstract static class StatsSnapshot { + static StatsSnapshot create( + int cpuInputCount, + int gpuInputCount, + int finishedCount, + int droppedCount, + long totalLatencyMs, + long peakLatencyMs, + long elapsedTimeMs) { + return new AutoValue_TasksStatsLogger_StatsSnapshot( + cpuInputCount, + gpuInputCount, + finishedCount, + droppedCount, + totalLatencyMs, + peakLatencyMs, + elapsedTimeMs); + } + + static StatsSnapshot createDefault() { + return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0); + } + + abstract int cpuInputCount(); + + abstract int gpuInputCount(); + + abstract int finishedCount(); + + abstract int droppedCount(); + + abstract long totalLatencyMs(); + + abstract long peakLatencyMs(); + + abstract long elapsedTimeMs(); + } + + /** Logs the start of a MediaPipe Tasks API session. */ + public void logSessionStart(); + + /** + * Records MediaPipe Tasks API receiving CPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordCpuInputArrival(long packetTimestamp); + + /** + * Records MediaPipe Tasks API receiving GPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordGpuInputArrival(long packetTimestamp); + + /** + * Records the end of a Mediapipe Tasks API invocation. + * + * @param packetTimestamp the output packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordInvocationEnd(long packetTimestamp); + + /** Logs the MediaPipe Tasks API periodic invocation report. */ + public void logInvocationReport(StatsSnapshot stats); + + /** Logs the Tasks API session end event. */ + public void logSessionEnd(); + + /** Logs the MediaPipe Tasks API initialization error. */ + public void logInitError(); + + // TODO: Logs more error types. 
+} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 0ea91a9f8..edb78a191 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextClassifier.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 9b464d0e8..28f351d4b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextEmbedder.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index e9e74a067..a933d2f65 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(GestureRecognizer.class.getSimpleName()) + .setTaskRunningModeName(recognizerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java index a9270d347..1d08ab928 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java @@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(HandLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 8990f46fd..38482797c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + 
.setTaskName(ImageClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index af053d860..488927257 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ImageEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 769b9137f..d706189ee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ObjectDetector.class.getSimpleName()) + .setTaskRunningModeName(detectorOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) From fd50b6aa2f6d1a8f69163fbda4db763bbd2862f4 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 15 Dec 2022 11:49:06 -0800 Subject: [PATCH 167/346] Add a new python unit test to test creating mediapipe Image from cvmat. 
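For readers who want to try this outside the test harness, this is the pattern the new test exercises; a minimal sketch, assuming an RGB JPEG on disk (the file name here is illustrative):

```python
import cv2
import mediapipe as mp
import numpy as np

# OpenCV decodes to BGR channel order; convert to RGB before wrapping in mp.Image.
mat = cv2.imread('hands.jpg').astype(np.uint8)
mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB)
rgb_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat)

# The wrapped image exposes the same pixel data back as a numpy view.
assert np.array_equal(mat, rgb_image.numpy_view())
```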
PiperOrigin-RevId: 495655719 --- mediapipe/python/image_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index 117d20974..cd9124948 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -28,6 +28,8 @@ import PIL.Image from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame +TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' + Image = image.Image ImageFormat = image_frame.ImageFormat @@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) + def test_image_create_from_cvmat(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + mat = cv2.imread(image_path).astype(np.uint8) + mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB) + rgb_image = Image(image_format=ImageFormat.SRGB, data=mat) + self.assertEqual(rgb_image.width, 720) + self.assertEqual(rgb_image.height, 382) + self.assertEqual(rgb_image.channels, 3) + self.assertEqual(rgb_image.image_format, ImageFormat.SRGB) + self.assertTrue(np.array_equal(mat, rgb_image.numpy_view())) + + def test_image_create_from_file(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + loaded_image = Image.create_from_file(image_path) + self.assertEqual(loaded_image.width, 720) + self.assertEqual(loaded_image.height, 382) + self.assertEqual(loaded_image.channels, 3) + self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) + if __name__ == '__main__': absltest.main() From 62f0034033ca9b4c0106bd7987669ce3604d2571 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Thu, 15 Dec 2022 14:17:23 -0800 Subject: [PATCH 168/346] Internal change PiperOrigin-RevId: 495694817 --- .github/ISSUE_TEMPLATE/13-solution-issue.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/13-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md index 9297edf6b..bf0d613c9 100644 --- a/.github/ISSUE_TEMPLATE/13-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- name: "Solution (legacy) Issue" -about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc. labels: type:support --- From 8d2473c751ca53fd22164a51b66bbd7fa13375f1 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Thu, 15 Dec 2022 15:42:00 -0800 Subject: [PATCH 169/346] Update `Image` docs to improve rendering. The [API docs](https://developers.google.com/mediapipe/api/solutions/python/mp/Image) have a few rendering issues. e.g., the doc generator will turn ``` This block: Anything here ``` Into a table with heading `This block` and `Anything here` as a plain-text cell. In order to render code as code, it needs to be in backticks. They can also be in `>>> code()` format, and we can try to run them ([doctests](https://docs.python.org/3/library/doctest.html)). I'll have a dashboard ready soon that shows areas we can improve. 
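For concreteness, the doctest style looks like this (the function is invented purely for illustration); an indented block introduced by a trailing colon, with no backticks and no `>>>` markers, is what gets misrendered as a table:

```python
def scale(x):
  """Scales the input by two.

  >>> scale(2)
  4
  """
  return x * 2
```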
PiperOrigin-RevId: 495715576 --- mediapipe/python/pybind/image.cc | 46 ++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 5d8663143..e5fa24e8c 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -48,16 +48,19 @@ void ImageSubmodule(pybind11::module* module) { become immutable after creation. Creation examples: - import cv2 - cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) - gray_frame = mp.Image( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) - from PIL import Image - pil_img = Image.new('RGB', (60, 30), color = 'red') - image = mp.Image( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ```python + import cv2 + cv_mat = cv2.imread(input_file)[:, :, ::-1] + rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) + gray_frame = mp.Image( + format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + + from PIL import Image + pil_img = Image.new('RGB', (60, 30), color = 'red') + image = mp.Image( + format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ``` The pixel data in an Image can be retrieved as a numpy ndarray by calling `Image.numpy_view()`. The returned numpy ndarray is a reference to the @@ -65,15 +68,18 @@ void ImageSubmodule(pybind11::module* module) { numpy ndarray, it's required to obtain a copy of it. Pixel data retrieval examples: - for channel in range(num_channel): - for col in range(width): - for row in range(height): - print(image[row, col, channel]) - output_ndarray = image.numpy_view() - print(output_ndarray[0, 0, 0]) - copied_ndarray = np.copy(output_ndarray) - copied_ndarray[0,0,0] = 0 + ```python + for channel in range(num_channel): + for col in range(width): + for row in range(height): + print(image[row, col, channel]) + + output_ndarray = image.numpy_view() + print(output_ndarray[0, 0, 0]) + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0,0,0] = 0 + ``` )doc", py::dynamic_attr()); @@ -156,9 +162,11 @@ void ImageSubmodule(pybind11::module* module) { An unwritable numpy ndarray. Examples: + ``` output_ndarray = image.numpy_view() copied_ndarray = np.copy(output_ndarray) copied_ndarray[0,0,0] = 0 + ``` )doc"); image.def( @@ -191,10 +199,12 @@ void ImageSubmodule(pybind11::module* module) { IndexError: If the index is invalid or out of bounds. Examples: + ``` for channel in range(num_channel): for col in range(width): for row in range(height): print(image[row, col, channel]) + ``` )doc"); image @@ -224,7 +234,9 @@ void ImageSubmodule(pybind11::module* module) { A boolean. Examples: + ``` image.is_aligned(16) + ``` )doc"); image.def_static( From 6bf5648430108154b2c2f24c1e8edb78275deac9 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 15 Dec 2022 17:40:29 -0800 Subject: [PATCH 170/346] Fix the documentation of the constructor of Image and ImageFrame Python classes. 
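The constructors take `image_format=`, not `format=`; a minimal sanity check of the corrected spelling (the array shape is arbitrary):

```python
import mediapipe as mp
import numpy as np

data = np.zeros((30, 60, 3), dtype=np.uint8)
# Both wrappers use the `image_format` keyword, matching the fixed docstrings.
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=data)
frame = mp.ImageFrame(image_format=mp.ImageFormat.SRGB, data=data)
```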
PiperOrigin-RevId: 495739875 --- mediapipe/python/pybind/image.cc | 7 ++++--- mediapipe/python/pybind/image_frame.cc | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index e5fa24e8c..1bcca12ff 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -52,14 +52,15 @@ void ImageSubmodule(pybind11::module* module) { ```python import cv2 cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) + rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat) gray_frame = mp.Image( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image pil_img = Image.new('RGB', (60, 30), color = 'red') image = mp.Image( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) ``` The pixel data in an Image can be retrieved as a numpy ndarray by calling diff --git a/mediapipe/python/pybind/image_frame.cc b/mediapipe/python/pybind/image_frame.cc index a7fc6bfe4..bc7a9753d 100644 --- a/mediapipe/python/pybind/image_frame.cc +++ b/mediapipe/python/pybind/image_frame.cc @@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) { Creation examples: import cv2 cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat) + rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat) gray_frame = mp.ImageFrame( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image pil_img = Image.new('RGB', (60, 30), color = 'red') image_frame = mp.ImageFrame( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the From 0a1f050f1fbff3b70c351178eab9ff6b94fcb1db Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 15 Dec 2022 17:50:48 -0800 Subject: [PATCH 171/346] Internal change PiperOrigin-RevId: 495741383 --- mediapipe/calculators/audio/BUILD | 4 ++-- mediapipe/calculators/internal/BUILD | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index 555f7543f..4a8f0f598 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "mfcc_mel_calculators_proto", srcs = ["mfcc_mel_calculators.proto"], diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index caade2dc3..8647e3f3f 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( From d5562241cc50ec34a04f1fb4f4172df7dbe008bf Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 15 Dec 2022 18:32:10 -0800 Subject: [PATCH 172/346] Tensor: Interoperability GPU/Cpu -> Ahwb by transforming the underlying storage into Ahwb with releasing previously Cpu/Gpu resources. PiperOrigin-RevId: 495748104 --- mediapipe/framework/formats/tensor.h | 2 +- mediapipe/framework/formats/tensor_ahwb.cc | 19 ++++++------ .../framework/formats/tensor_ahwb_gpu_test.cc | 30 +++++++++++++++++++ 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 9d3e90b6a..f5a99cde1 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -408,8 +408,8 @@ class Tensor { mutable std::function release_callback_; bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; - // Use Ahwb for other views: OpenGL / CPU buffer. #endif // MEDIAPIPE_TENSOR_USE_AHWB + // Use Ahwb for other views: OpenGL / CPU buffer. static inline bool use_ahwb_ = false; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 3c3ec8b17..363c5efd0 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -212,9 +212,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(!(valid_ & kValidOpenGlTexture2d)) << "Tensor conversion between OpenGL texture and AHardwareBuffer is not " "supported."; - CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) - << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on target system."; bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; @@ -315,7 +312,13 @@ void Tensor::MoveCpuOrSsboToAhwb() const { ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); CHECK(error == 0) << "AHardwareBuffer_lock " << error; } - if (valid_ & kValidOpenGlBuffer) { + if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. + free(cpu_buffer_); + cpu_buffer_ = nullptr; + valid_ &= ~kValidCpu; + } else if (valid_ & kValidOpenGlBuffer) { gl_context_->Run([this, dest]() { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), @@ -326,11 +329,9 @@ void Tensor::MoveCpuOrSsboToAhwb() const { }); opengl_buffer_ = GL_INVALID_INDEX; gl_context_ = nullptr; - } else if (valid_ & kValidCpu) { - std::memcpy(dest, cpu_buffer_, bytes()); - // Free CPU memory because next time AHWB is mapped instead. - free(cpu_buffer_); - cpu_buffer_ = nullptr; + // Reset OpenGL Buffer validness. The OpenGL buffer will be allocated on top + // of the Ahwb at the next request to the OpenGlBufferView. 
+ valid_ &= ~kValidOpenGlBuffer; } else { LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; } diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index 7ccd9c7f5..a6ca00949 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -152,6 +152,36 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { { auto view = tensor.GetAHardwareBufferReadView(); EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { + // Request the GPU view to get the ssbo allocated internally. + // Request Ahwb view then to transform the storage into Ahwb. + Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + { + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); EXPECT_NE(ptr, nullptr); From b45554623af211792ab394459088506593295a4c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 16 Dec 2022 13:39:31 -0800 Subject: [PATCH 173/346] Fix typo in GetVectorItemCalculator doc PiperOrigin-RevId: 495951016 --- mediapipe/calculators/core/get_vector_item_calculator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index 25d90bfe6..ee886b381 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -47,7 +47,7 @@ namespace api2 { // calculator: "Get{SpecificType}VectorItemCalculator" // input_stream: "VECTOR:vector" // input_stream: "INDEX:index" -// input_stream: "ITEM:item" +// output_stream: "ITEM:item" // options { // [mediapipe.GetVectorItemCalculatorOptions.ext] { // item_index: 5 From 7ce4bb72d4da5fae2d52ee85d0d10dae5dd96f31 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 19 Dec 2022 09:00:46 -0800 Subject: [PATCH 174/346] Replace numpy.float with the builtin float type as numpy removes its own float type in v1.24. 
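`np.float` was a long-deprecated alias for the builtin `float`, so the replacement is behavior-preserving; where an actual NumPy dtype is intended, it should be spelled out explicitly:

```python
import numpy as np

x = float(0.42)                      # instead of the removed np.float(0.42)
arr = np.zeros(3, dtype=np.float64)  # use an explicit dtype when one is meant
```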
PiperOrigin-RevId: 496412858 --- mediapipe/python/packet_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/python/packet_test.py b/mediapipe/python/packet_test.py index e1a4c12af..16fc37c87 100644 --- a/mediapipe/python/packet_test.py +++ b/mediapipe/python/packet_test.py @@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase): p.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p), 0.42) self.assertEqual(p.timestamp, 0) - p2 = packet_creator.create_float(np.float(0.42)) + p2 = packet_creator.create_float(float(0.42)) p2.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p2), 0.42) self.assertEqual(p2.timestamp, 0) From 482247697407106cffdc7beb646477443b573557 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 19 Dec 2022 11:05:23 -0800 Subject: [PATCH 175/346] Internal change PiperOrigin-RevId: 496443946 --- mediapipe/tasks/web/core/fileset_resolver.ts | 24 +++++++------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts index d4691243b..9917035a4 100644 --- a/mediapipe/tasks/web/core/fileset_resolver.ts +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -44,22 +44,14 @@ async function isSimdSupported(): Promise { } async function createFileset( - taskName: string, basePath: string = '.'): Promise { - if (await isSimdSupported()) { - return { - wasmLoaderPath: - `${basePath}/${taskName}_wasm_internal.js`, - wasmBinaryPath: - `${basePath}/${taskName}_wasm_internal.wasm`, - }; - } else { - return { - wasmLoaderPath: - `${basePath}/${taskName}_wasm_nosimd_internal.js`, - wasmBinaryPath: - `${basePath}/${taskName}_wasm_nosimd_internal.wasm`, - }; - } + taskName: string, basePath: string = ''): Promise { + const suffix = + await isSimdSupported() ? 
'wasm_internal' : 'wasm_nosimd_internal'; + + return { + wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`, + wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`, + }; } // tslint:disable:class-as-namespace From 3e6cd5d2bf403299886bfdcb77079d92c2d794b5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 11:54:57 -0800 Subject: [PATCH 176/346] Add support for customizing gesture recognizer layers PiperOrigin-RevId: 496456160 --- .../gesture_recognizer/gesture_recognizer.py | 15 +++++++---- .../gesture_recognizer_test.py | 26 +++++++++++++++++++ .../gesture_recognizer/model_options.py | 6 +++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index f297d8640..556d2fcd7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier): batch_size=None, dtype=tf.float32, name='hand_embedding') - - x = tf.keras.layers.BatchNormalization()(inputs) - x = tf.keras.layers.ReLU()(x) + x = inputs dropout_rate = self._model_options.dropout_rate - x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x) + for i, width in enumerate(self._model_options.layer_widths): + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) + x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x) + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) outputs = tf.keras.layers.Dense( self._num_classes, activation='softmax', - name='custom_gesture_recognizer')( + name='custom_gesture_recognizer_out')( x) self._model = tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 280fc6a82..08fda4fea 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -60,6 +60,32 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) + @unittest_mock.patch.object( + tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) + def test_gesture_recognizer_model_layer_widths(self, mock_dense): + layer_widths = [64, 32] + model_options = gesture_recognizer.ModelOptions(layer_widths=layer_widths) + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._validation_data, + options=gesture_recognizer_options) + expected_calls = [ + unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}') + for i, w in enumerate(layer_widths) + ] + expected_calls.append( + unittest_mock.call( + len(self._train_data.label_names), + activation='softmax', + name='custom_gesture_recognizer_out')) + self.assertLen(mock_dense.call_args_list, len(expected_calls)) + mock_dense.assert_has_calls(expected_calls) + self._test_accuracy(model) + def 
test_export_gesture_recognizer_model(self): model_options = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py index 79a84c792..1870437d4 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py @@ -14,6 +14,7 @@ """Configurable model options for gesture recognizer models.""" import dataclasses +from typing import List @dataclasses.dataclass @@ -23,5 +24,10 @@ class GestureRecognizerModelOptions: Attributes: dropout_rate: The fraction of the input units to drop, used in dropout layer. + layer_widths: A list of hidden layer widths for the gesture model. Each + element in the list will create a new hidden layer with the specified + width. The hidden layers are separated with BatchNorm, Dropout, and ReLU. + Defaults to an empty list(no hidden layers). """ dropout_rate: float = 0.05 + layer_widths: List[int] = dataclasses.field(default_factory=list) From ef3fa67bf423e2d1c2ffba2bab01cc1c7b5d2ba5 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Mon, 19 Dec 2022 12:36:07 -0800 Subject: [PATCH 177/346] Automatic selection of the tensor's storage type by recording previously requested views. PiperOrigin-RevId: 496466136 --- mediapipe/framework/formats/BUILD | 6 ++- mediapipe/framework/formats/tensor.cc | 35 +++++------------- mediapipe/framework/formats/tensor.h | 37 ++++++++++++++++--- mediapipe/framework/formats/tensor_ahwb.cc | 15 ++++++++ mediapipe/framework/formats/tensor_internal.h | 10 ++--- 5 files changed, 67 insertions(+), 36 deletions(-) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index fdb698c48..fdd9b8909 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -428,7 +428,10 @@ cc_library( "tensor.cc", "tensor_ahwb.cc", ], - hdrs = ["tensor.h"], + hdrs = [ + "tensor.h", + "tensor_internal.h", + ], copts = select({ "//mediapipe:apple": [ "-x objective-c++", @@ -452,6 +455,7 @@ cc_library( ], }), deps = [ + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index fdafbff5c..3f11d368a 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -246,10 +246,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, return Tensor::OpenGlTexture2dView::Layout::kAligned; } } - // The best performance of a compute shader can be achived with textures' + // The best performance of a compute shader can be achieved with textures' // width multiple of 256. Making minimum fixed width of 256 waste memory for // small tensors. The optimal balance memory-vs-performance is power of 2. - // The texture width and height are choosen to be closer to square. + // The texture width and height are chosen to be closer to square. 
float power = std::log2(std::sqrt(static_cast(num_pixels))); w = 1 << static_cast(power); int h = (num_pixels + w - 1) / w; @@ -326,7 +326,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); if (!(valid_ & kValidOpenGlBuffer)) { - // If the call succeds then AHWB -> SSBO are synchronized so any usage of + // If the call succeeds then AHWB -> SSBO are synchronized so any usage of // the SSBO is correct after this call. if (!InsertAhwbToSsboFence()) { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); @@ -348,8 +348,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { }; } -Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const { +Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView( + uint64_t source_location_hash) const { auto lock(absl::make_unique(&view_mutex_)); + TrackAhwbUsage(source_location_hash); AllocateOpenGlBuffer(); valid_ = kValidOpenGlBuffer; return {opengl_buffer_, std::move(lock), nullptr}; @@ -385,6 +387,7 @@ void Tensor::Move(Tensor* src) { src->element_type_ = ElementType::kNone; // Mark as invalidated. cpu_buffer_ = src->cpu_buffer_; src->cpu_buffer_ = nullptr; + ahwb_tracking_key_ = src->ahwb_tracking_key_; #if MEDIAPIPE_METAL_ENABLED device_ = src->device_; src->device_ = nil; @@ -589,8 +592,10 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { return {cpu_buffer_, std::move(lock)}; } -Tensor::CpuWriteView Tensor::GetCpuWriteView() const { +Tensor::CpuWriteView Tensor::GetCpuWriteView( + uint64_t source_location_hash) const { auto lock = absl::make_unique(&view_mutex_); + TrackAhwbUsage(source_location_hash); AllocateCpuBuffer(); valid_ = kValidCpu; #ifdef MEDIAPIPE_TENSOR_USE_AHWB @@ -620,24 +625,4 @@ void Tensor::AllocateCpuBuffer() const { } } -void Tensor::SetPreferredStorageType(StorageType type) { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (__builtin_available(android 26, *)) { - use_ahwb_ = type == StorageType::kAhwb; - VLOG(4) << "Tensor: use of AHardwareBuffer is " - << (use_ahwb_ ? "allowed" : "not allowed"); - } -#else - VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed"; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - -Tensor::StorageType Tensor::GetPreferredStorageType() { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - return use_ahwb_ ? 
StorageType::kAhwb : StorageType::kDefault; -#else - return StorageType::kDefault; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - } // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index f5a99cde1..8a6f02e9d 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,8 +24,9 @@ #include #include -#include "absl/memory/memory.h" +#include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" #if MEDIAPIPE_METAL_ENABLED @@ -48,6 +49,22 @@ #include "mediapipe/gpu/gl_context.h" #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 +#if defined __has_builtin +#if __has_builtin(__builtin_LINE) +#define builtin_LINE __builtin_LINE +#endif +#if __has_builtin(__builtin_FILE) +#define builtin_FILE __builtin_FILE +#endif +#endif + +#ifndef builtin_LINE +#define builtin_LINE() 0 +#endif +#ifndef builtin_FILE +#define builtin_FILE() "" +#endif + namespace mediapipe { // Tensor is a container of multi-dimensional data that supports sharing the @@ -65,7 +82,7 @@ namespace mediapipe { // GLuint buffer = view.buffer(); // Then the buffer can be bound to the GPU command buffer. // ...binding the buffer to the command buffer... -// ...commiting command buffer and releasing the view... +// ...committing command buffer and releasing the view... // // The following request for the CPU view will be blocked until the GPU view is // released and the GPU task is finished. @@ -161,7 +178,9 @@ class Tensor { using CpuReadView = CpuView; CpuReadView GetCpuReadView() const; using CpuWriteView = CpuView; - CpuWriteView GetCpuWriteView() const; + CpuWriteView GetCpuWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #if MEDIAPIPE_METAL_ENABLED // TODO: id vs. MtlBufferView. @@ -305,7 +324,9 @@ class Tensor { // A valid OpenGL context must be bound to the calling thread due to possible // GPU resource allocation. OpenGlBufferView GetOpenGlBufferReadView() const; - OpenGlBufferView GetOpenGlBufferWriteView() const; + OpenGlBufferView GetOpenGlBufferWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 const Shape& shape() const { return shape_; } @@ -410,7 +431,11 @@ class Tensor { void CreateEglSyncAndFd() const; #endif // MEDIAPIPE_TENSOR_USE_AHWB // Use Ahwb for other views: OpenGL / CPU buffer. - static inline bool use_ahwb_ = false; + mutable bool use_ahwb_ = false; + mutable uint64_t ahwb_tracking_key_ = 0; + // TODO: Tracks all unique tensors. Can grow to a large number. LRU + // can be more predicted. + static inline absl::flat_hash_set ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; @@ -419,6 +444,8 @@ class Tensor { void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; void MoveCpuOrSsboToAhwb() const; + // Set current tracking key, set "use ahwb" if the key is already marked. 
+ void TrackAhwbUsage(uint64_t key) const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 363c5efd0..466811be7 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -265,6 +265,10 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { + // Mark current tracking key as Ahwb-use. + ahwb_usage_track_.insert(ahwb_tracking_key_); + use_ahwb_ = true; + if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; @@ -447,6 +451,16 @@ void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { + if (ahwb_tracking_key_ == 0) { + ahwb_tracking_key_ = source_location_hash; + for (int dim : shape_.dims) { + ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); + } + } + use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); +} + #else // MEDIAPIPE_TENSOR_USE_AHWB bool Tensor::AllocateAhwbMapToSsbo() const { return false; } @@ -455,6 +469,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {} void Tensor::ReleaseAhwbStuff() {} void* Tensor::MapAhwbToCpuRead() const { return nullptr; } void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t key) const {} #endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor_internal.h index 1231a991c..c223c5b1d 100644 --- a/mediapipe/framework/formats/tensor_internal.h +++ b/mediapipe/framework/formats/tensor_internal.h @@ -18,8 +18,6 @@ #include #include -#include "mediapipe/framework/tool/type_util.h" - namespace mediapipe { // Generates unique view id at compile-time using FILE and LINE. @@ -41,10 +39,12 @@ namespace tensor_internal { // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function constexpr uint64_t kFnvPrime = 0x00000100000001B3; constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325; -constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { - return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime); +constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) { + return (value2 ^ value1) * kFnvPrime; +} +constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { + return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0])); } - template struct TypeList { static constexpr std::size_t size{sizeof...(Ts)}; From ea0bebc22608b9bbc2c0173460a418122bea4861 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 14:48:47 -0800 Subject: [PATCH 178/346] Add BGR -> RGB color conversion to ColorConvertCalculator. 
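A minimal sketch of exercising the new conversion path from Python, assuming ColorConvertCalculator is linked into the running binary (the stream names are illustrative, and packet feeding/observation and error handling are omitted):

```python
from google.protobuf import text_format
import mediapipe as mp
from mediapipe.framework import calculator_pb2

# Text-format graph with a single node using the new BGR_IN -> RGB_OUT tags.
config = text_format.Parse("""
  input_stream: 'bgr_frame'
  output_stream: 'rgb_frame'
  node {
    calculator: 'ColorConvertCalculator'
    input_stream: 'BGR_IN:bgr_frame'
    output_stream: 'RGB_OUT:rgb_frame'
  }
""", calculator_pb2.CalculatorGraphConfig())
graph = mp.CalculatorGraph(graph_config=config)
```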
PiperOrigin-RevId: 496497002 --- .../calculators/image/color_convert_calculator.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index bdac932bb..4781f1ea1 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { constexpr char kRgbaInTag[] = "RGBA_IN"; constexpr char kRgbInTag[] = "RGB_IN"; +constexpr char kBgrInTag[] = "BGR_IN"; constexpr char kBgraInTag[] = "BGRA_IN"; constexpr char kGrayInTag[] = "GRAY_IN"; constexpr char kRgbaOutTag[] = "RGBA_OUT"; @@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB -> RGBA // RGBA -> BGRA // BGRA -> RGBA +// BGR -> RGB // // This calculator only supports a single input stream and output stream at a // time. If more than one input stream or output stream is present, the @@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB_IN: The input video stream (ImageFrame, SRGB). // BGRA_IN: The input video stream (ImageFrame, SBGRA). // GRAY_IN: The input video stream (ImageFrame, GRAY8). +// BGR_IN: The input video stream (ImageFrame, SBGR). // // Output streams: // RGBA_OUT: The output video stream (ImageFrame, SRGBA). @@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kBgraInTag).Set(); } + if (cc->Inputs().HasTag(kBgrInTag)) { + cc->Inputs().Tag(kBgrInTag).Set(); + } + if (cc->Outputs().HasTag(kRgbOutTag)) { cc->Outputs().Tag(kRgbOutTag).Set(); } @@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) { return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA, cv::COLOR_RGBA2BGRA, cc); } + // BGR -> RGB + if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) { + return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB, + cv::COLOR_BGR2RGB, cc); + } return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format conversion."; From 6842f2c7c6657e7645ecddd26c92504e1b797f84 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 17:12:06 -0800 Subject: [PATCH 179/346] Use the proper namespace for builder test PiperOrigin-RevId: 496526588 --- mediapipe/framework/api2/builder_test.cc | 131 ++++++++++++----------- 1 file changed, 66 insertions(+), 65 deletions(-) diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3bf3ec198..361f740c4 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -15,12 +15,17 @@ #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -namespace mediapipe { -namespace api2 { -namespace test { +namespace mediapipe::api2::builder { +namespace { + +using ::mediapipe::api2::test::Bar; +using ::mediapipe::api2::test::FloatAdder; +using ::mediapipe::api2::test::Foo; +using ::mediapipe::api2::test::Foo2; +using ::mediapipe::api2::test::FooBar1; TEST(BuilderTest, BuildGraph) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); auto& bar = graph.AddNode("Bar"); graph.In("IN").SetName("base") >> foo.In("BASE"); @@ -49,22 +54,21 @@ TEST(BuilderTest, BuildGraph) { } TEST(BuilderTest, CopyableSource) { - builder::Graph graph; - builder::Source a = graph[Input("A")]; + Graph graph; + Source a = 
graph[Input("A")]; a.SetName("a"); - builder::Source b = graph[Input("B")]; + Source b = graph[Input("B")]; b.SetName("b"); - builder::SideSource side_a = graph[SideInput("SIDE_A")]; + SideSource side_a = graph[SideInput("SIDE_A")]; side_a.SetName("side_a"); - builder::SideSource side_b = graph[SideInput("SIDE_B")]; + SideSource side_b = graph[SideInput("SIDE_B")]; side_b.SetName("side_b"); - builder::Destination out = graph[Output("OUT")]; - builder::SideDestination side_out = - graph[SideOutput("SIDE_OUT")]; + Destination out = graph[Output("OUT")]; + SideDestination side_out = graph[SideOutput("SIDE_OUT")]; - builder::Source input = a; + Source input = a; input = b; - builder::SideSource side_input = side_b; + SideSource side_input = side_b; side_input = side_a; input >> out; @@ -83,28 +87,27 @@ TEST(BuilderTest, CopyableSource) { } TEST(BuilderTest, BuildGraphWithFunctions) { - builder::Graph graph; + Graph graph; - builder::Source base = graph[Input("IN")]; + Source base = graph[Input("IN")]; base.SetName("base"); - builder::SideSource side = graph[SideInput("SIDE")]; + SideSource side = graph[SideInput("SIDE")]; side.SetName("side"); - auto foo_fn = [](builder::Source base, builder::SideSource side, - builder::Graph& graph) { + auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); base >> foo[Input("BASE")]; side >> foo[SideInput("SIDE")]; return foo[Output("OUT")]; }; - builder::Source foo_out = foo_fn(base, side, graph); + Source foo_out = foo_fn(base, side, graph); - auto bar_fn = [](builder::Source in, builder::Graph& graph) { + auto bar_fn = [](Source in, Graph& graph) { auto& bar = graph.AddNode("Bar"); in >> bar[Input("IN")]; return bar[Output("OUT")]; }; - builder::Source bar_out = bar_fn(foo_out, graph); + Source bar_out = bar_fn(foo_out, graph); bar_out.SetName("out"); bar_out >> graph[Output("OUT")]; @@ -131,7 +134,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) { template void BuildGraphTypedTest() { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode(); auto& bar = graph.AddNode(); graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); @@ -161,12 +164,12 @@ void BuildGraphTypedTest() { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } -TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } TEST(BuilderTest, FanOut) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); auto& adder = graph.AddNode("FloatAdder"); graph.In("IN").SetName("base") >> foo.In("BASE"); @@ -194,9 +197,9 @@ TEST(BuilderTest, FanOut) { } TEST(BuilderTest, TypedMultiple) { - builder::Graph graph; - auto& foo = graph.AddNode(); - auto& adder = graph.AddNode(); + Graph graph; + auto& foo = graph.AddNode(); + auto& adder = graph.AddNode(); graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0]; foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1]; @@ -222,8 +225,8 @@ TEST(BuilderTest, TypedMultiple) { } TEST(BuilderTest, TypedByPorts) { - builder::Graph graph; - auto& foo = graph.AddNode(); + Graph graph; + auto& foo = graph.AddNode(); auto& adder = graph.AddNode(); graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase]; @@ -251,7 +254,7 @@ TEST(BuilderTest, TypedByPorts) { } TEST(BuilderTest, PacketGenerator) { - builder::Graph graph; + Graph graph; auto& generator 
= graph.AddPacketGenerator("FloatGenerator"); graph.SideIn("IN") >> generator.SideIn("IN"); generator.SideOut("OUT") >> graph.SideOut("OUT"); @@ -270,7 +273,7 @@ TEST(BuilderTest, PacketGenerator) { } TEST(BuilderTest, EmptyTag) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In("A").SetName("a") >> foo.In("")[0]; graph.In("C").SetName("c") >> foo.In("")[2]; @@ -302,7 +305,7 @@ TEST(BuilderTest, StringLikeTags) { const std::string kB = "B"; constexpr absl::string_view kC = "C"; - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In(kA).SetName("a") >> foo.In(kA); graph.In(kB).SetName("b") >> foo.In(kB); @@ -324,7 +327,7 @@ TEST(BuilderTest, StringLikeTags) { } TEST(BuilderTest, GraphIndexes) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In(0).SetName("a") >> foo.In("")[0]; graph.In(1).SetName("c") >> foo.In("")[2]; @@ -376,28 +379,27 @@ class AnyAndSameTypeCalculator : public NodeIntf { }; TEST(BuilderTest, AnyAndSameTypeHandledProperly) { - builder::Graph graph; - builder::Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; - builder::Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; + Graph graph; + Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; + Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - builder::Source any_type_output = + Source any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; any_type_output.SetName("any_type_output"); - builder::Source same_type_output = + Source same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; same_type_output.SetName("same_type_output"); - builder::Source recursive_same_type_output = + Source recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; recursive_same_type_output.SetName("recursive_same_type_output"); - builder::Source same_int_output = - node[AnyAndSameTypeCalculator::kSameIntOutput]; + Source same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; same_int_output.SetName("same_int_output"); - builder::Source recursive_same_int_type_output = + Source recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; recursive_same_int_type_output.SetName("recursive_same_int_type_output"); @@ -420,13 +422,13 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { } TEST(BuilderTest, AnyTypeCanBeCast) { - builder::Graph graph; - builder::Source any_input = + Graph graph; + Source any_input = graph.In("GRAPH_ANY_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; - builder::Source any_type_output = + Source any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); any_type_output.SetName("any_type_output"); @@ -446,11 +448,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) { } TEST(BuilderTest, MultiPortIsCastToMultiPort) { - builder::Graph graph; - builder::MultiSource any_input = graph.In("ANY_INPUT"); - builder::MultiSource int_input = any_input.Cast(); - builder::MultiDestination any_output = graph.Out("ANY_OUTPUT"); - builder::MultiDestination int_output = any_output.Cast(); + Graph graph; + MultiSource any_input = graph.In("ANY_INPUT"); + MultiSource int_input = any_input.Cast(); + MultiDestination any_output = graph.Out("ANY_OUTPUT"); + MultiDestination int_output = any_output.Cast(); 
int_input >> int_output; CalculatorGraphConfig expected = @@ -462,11 +464,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) { } TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { - builder::Graph graph; - builder::MultiSource any_multi_input = graph.In("ANY_INPUT"); - builder::Source any_input = any_multi_input; - builder::MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); - builder::Destination any_output = any_multi_output; + Graph graph; + MultiSource any_multi_input = graph.In("ANY_INPUT"); + Source any_input = any_multi_input; + MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); + Destination any_output = any_multi_output; any_input >> any_output; CalculatorGraphConfig expected = @@ -478,11 +480,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { } TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { - builder::Graph graph; - builder::Source int_input = graph.In("INT_INPUT").Cast(); - builder::Source any_input = graph.In("ANY_OUTPUT"); - builder::Destination int_output = graph.Out("INT_OUTPUT").Cast(); - builder::Destination any_output = graph.Out("ANY_OUTPUT"); + Graph graph; + Source int_input = graph.In("INT_INPUT").Cast(); + Source any_input = graph.In("ANY_OUTPUT"); + Destination int_output = graph.Out("INT_OUTPUT").Cast(); + Destination any_output = graph.Out("ANY_OUTPUT"); int_input >> int_output; any_input >> any_output; @@ -496,6 +498,5 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -} // namespace test -} // namespace api2 -} // namespace mediapipe +} // namespace +} // namespace mediapipe::api2::builder From f5f2fee0b9b2cccdf9fd04c6cb4b96fd8c1bc7ee Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 17:16:04 -0800 Subject: [PATCH 180/346] Switch to Cast where possible and reduce usage of operator[](port). 
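The typed indexing form graph[Input<T>("TAG")] (and its SideInput/Output/SideOutput
counterparts) is replaced by the untyped accessors In()/Out()/SideIn()/SideOut(),
with an explicit Cast<T>() added only where a typed Source or Destination is still
required. A minimal sketch of the resulting style, for illustration only — the int
payload type is assumed, and the "Foo" node with BASE/OUT ports mirrors the test
fixtures in the hunks below:

    #include "mediapipe/framework/api2/builder.h"

    using ::mediapipe::CalculatorGraphConfig;
    using ::mediapipe::api2::builder::Graph;
    using ::mediapipe::api2::builder::Source;

    void BuildExample() {
      Graph graph;
      // Previously: Source<int> a = graph[Input<int>("A")];
      Source<int> a = graph.In("A").Cast<int>();
      auto& foo = graph.AddNode("Foo");
      a >> foo.In("BASE");                 // connect graph input "A" to Foo's BASE
      foo.Out("OUT") >> graph.Out("OUT");  // expose Foo's output on the graph
      CalculatorGraphConfig config = graph.GetConfig();
    }

Where no typed handle is needed, untyped ports connect directly
(graph.In("A") >> foo.In("BASE")) with no Cast at all, which is what most of the
hunks below reduce to.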
PiperOrigin-RevId: 496527250 --- mediapipe/framework/api2/builder_test.cc | 36 ++++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 361f740c4..d8522b3c8 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -55,16 +55,16 @@ TEST(BuilderTest, BuildGraph) { TEST(BuilderTest, CopyableSource) { Graph graph; - Source a = graph[Input("A")]; + Source a = graph.In("A").Cast(); a.SetName("a"); - Source b = graph[Input("B")]; + Source b = graph.In("B").Cast(); b.SetName("b"); - SideSource side_a = graph[SideInput("SIDE_A")]; + SideSource side_a = graph.SideIn("SIDE_A").Cast(); side_a.SetName("side_a"); - SideSource side_b = graph[SideInput("SIDE_B")]; + SideSource side_b = graph.SideIn("SIDE_B").Cast(); side_b.SetName("side_b"); - Destination out = graph[Output("OUT")]; - SideDestination side_out = graph[SideOutput("SIDE_OUT")]; + Destination out = graph.Out("OUT").Cast(); + SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); Source input = a; input = b; @@ -89,28 +89,28 @@ TEST(BuilderTest, CopyableSource) { TEST(BuilderTest, BuildGraphWithFunctions) { Graph graph; - Source base = graph[Input("IN")]; + Source base = graph.In("IN").Cast(); base.SetName("base"); - SideSource side = graph[SideInput("SIDE")]; + SideSource side = graph.SideIn("SIDE").Cast(); side.SetName("side"); auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); - base >> foo[Input("BASE")]; - side >> foo[SideInput("SIDE")]; - return foo[Output("OUT")]; + base >> foo.In("BASE"); + side >> foo.SideIn("SIDE"); + return foo.Out("OUT")[0].Cast(); }; Source foo_out = foo_fn(base, side, graph); auto bar_fn = [](Source in, Graph& graph) { auto& bar = graph.AddNode("Bar"); - in >> bar[Input("IN")]; - return bar[Output("OUT")]; + in >> bar.In("IN"); + return bar.Out("OUT")[0].Cast(); }; Source bar_out = bar_fn(foo_out, graph); bar_out.SetName("out"); - bar_out >> graph[Output("OUT")]; + bar_out >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -229,10 +229,10 @@ TEST(BuilderTest, TypedByPorts) { auto& foo = graph.AddNode(); auto& adder = graph.AddNode(); - graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase]; + graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase]; foo[Foo::kOut] >> adder[FloatAdder::kIn][0]; foo[Foo::kOut] >> adder[FloatAdder::kIn][1]; - adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut]; + adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -380,8 +380,8 @@ class AnyAndSameTypeCalculator : public NodeIntf { TEST(BuilderTest, AnyAndSameTypeHandledProperly) { Graph graph; - Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; - Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; + Source any_input = graph.In("GRAPH_ANY_INPUT"); + Source int_input = graph.In("GRAPH_INT_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; From 994eb28d2c007ebc09795b300cedf0abe7130507 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 18:05:30 -0800 Subject: [PATCH 181/346] Chain SetName calls where possible PiperOrigin-RevId: 496534328 --- mediapipe/framework/api2/builder_test.cc | 28 ++++++++++-------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git 
a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index d8522b3c8..b01c2b759 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -55,14 +55,12 @@ TEST(BuilderTest, BuildGraph) { TEST(BuilderTest, CopyableSource) { Graph graph; - Source a = graph.In("A").Cast(); - a.SetName("a"); - Source b = graph.In("B").Cast(); - b.SetName("b"); - SideSource side_a = graph.SideIn("SIDE_A").Cast(); - side_a.SetName("side_a"); - SideSource side_b = graph.SideIn("SIDE_B").Cast(); - side_b.SetName("side_b"); + Source a = graph.In("A").SetName("a").Cast(); + Source b = graph.In("B").SetName("b").Cast(); + SideSource side_a = + graph.SideIn("SIDE_A").SetName("side_a").Cast(); + SideSource side_b = + graph.SideIn("SIDE_B").SetName("side_b").Cast(); Destination out = graph.Out("OUT").Cast(); SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); @@ -89,10 +87,8 @@ TEST(BuilderTest, CopyableSource) { TEST(BuilderTest, BuildGraphWithFunctions) { Graph graph; - Source base = graph.In("IN").Cast(); - base.SetName("base"); - SideSource side = graph.SideIn("SIDE").Cast(); - side.SetName("side"); + Source base = graph.In("IN").SetName("base").Cast(); + SideSource side = graph.SideIn("SIDE").SetName("side").Cast(); auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); @@ -108,9 +104,8 @@ TEST(BuilderTest, BuildGraphWithFunctions) { return bar.Out("OUT")[0].Cast(); }; Source bar_out = bar_fn(foo_out, graph); - bar_out.SetName("out"); - bar_out >> graph.Out("OUT"); + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -429,8 +424,9 @@ TEST(BuilderTest, AnyTypeCanBeCast) { auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; Source any_type_output = - node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); - any_type_output.SetName("any_type_output"); + node[AnyAndSameTypeCalculator::kAnyTypeOutput] + .SetName("any_type_output") + .Cast(); any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); From 90678040057ee23dc2ad29c6982010a260c7b7cd Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 19 Dec 2022 19:39:00 -0800 Subject: [PATCH 182/346] Fix the missing logging component issue of mediapipe tasks core. 
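In the mediapipe_tasks_core_aar target, glob(["*.java"]) matches only Java files
sitting directly in this package's directory, so sources in subdirectories (the
tasks core logging component named in the subject) were silently left out of the
AAR. Switching to the recursive pattern picks them up:

    srcs = glob(["*.java"]) + [ ...     before: top-level .java files only
    srcs = glob(["**/*.java"]) + [ ...  after: also .java files in subdirectories

Note that Bazel's glob never descends into a subdirectory that is itself a package
(i.e., contains its own BUILD file), so this only adds loose subdirectory sources
belonging to this package.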
PiperOrigin-RevId: 496548340 --- mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 3eb28d38b..5f7101776 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -53,7 +53,7 @@ load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl" mediapipe_tasks_core_aar( name = "tasks_core", - srcs = glob(["*.java"]) + [ + srcs = glob(["**/*.java"]) + [ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src", From 4682416f0f426e8302b4181a7085713ac1c6e38c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 19 Dec 2022 22:07:55 -0800 Subject: [PATCH 183/346] Internal change PiperOrigin-RevId: 496568835 --- mediapipe/calculators/core/BUILD | 14 ++--- mediapipe/calculators/image/BUILD | 4 +- mediapipe/calculators/util/BUILD | 2 +- mediapipe/framework/BUILD | 70 ++++++++++++------------ mediapipe/framework/formats/BUILD | 24 ++++---- mediapipe/framework/formats/motion/BUILD | 2 +- mediapipe/framework/stream_handler/BUILD | 14 ++--- mediapipe/framework/tool/BUILD | 18 +++--- mediapipe/gpu/BUILD | 14 ++--- 9 files changed, 81 insertions(+), 81 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 2c143a609..b3378a74e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ + ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - "//mediapipe/calculators/core:counting_source_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 530dd3d4a..9aae8cfbc 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ + ":scale_image_calculator_cc_proto", ":scale_image_utils", - "//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", + ":image_transformation_calculator", ":warp_affine_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 1529ead8a..a679a80fd 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ + ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", - "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 0dd694760..082ea9994 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -226,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ + ":calculator_cc_proto", ":graph_service", + ":mediapipe_options_cc_proto", + ":packet_generator_cc_proto", ":packet_type", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":calculator_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -391,6 +391,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", 
":calculator_state", @@ -407,10 +408,9 @@ cc_library( ":packet_set", ":packet_type", ":port", + ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,6 +466,7 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ + ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -475,7 +476,6 @@ cc_library( ":packet", ":packet_set", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +583,7 @@ cc_library( hdrs = ["executor.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:mediapipe_options_cc_proto", + ":mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", + ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", + ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", + ":packet_factory_cc_proto", ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", - "//mediapipe/framework:packet_factory_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ + ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1097,7 +1097,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":executor", - "//mediapipe/framework:thread_pool_executor_cc_proto", 
+ ":thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1162,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", ":subgraph", + ":thread_pool_executor_cc_proto", ":timestamp", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1202,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1233,6 +1233,7 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1242,7 +1243,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1256,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1368,6 +1368,7 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1376,7 +1377,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1403,6 +1403,7 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", + ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1410,13 +1411,12 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", + ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1481,12 +1481,12 @@ cc_test( ], visibility 
= ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index fdd9b8909..f5a043f10 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = ["//mediapipe/framework/formats:location_data_proto"], + deps = [":location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = ["//mediapipe/framework/formats:detection_cc_proto"], + deps = [":detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ + ":matrix_data_cc_proto", "//mediapipe/framework:port", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ + ":image_format_cc_proto", ":image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", + ":location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ + ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:image_frame", + ":image_format_cc_proto", + ":image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = 
["image_multi_pool.h"], deps = [ ":image", - "//mediapipe/framework/formats:image_frame_pool", + ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index f1bbc0289..c9bb8b4ff 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = [ + ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 01ef6ee86..68a9af52d 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ + ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ + ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,6 +243,7 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ 
-251,7 +252,6 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 89cb802da..193343a90 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,6 +299,7 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", + ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -313,7 +314,6 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", - "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ + ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,6 +506,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -515,7 +516,6 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 
+661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ + ":simulation_clock", "//mediapipe/framework:thread_pool_executor", - "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,6 +805,7 @@ cc_library( deps = [ ":container_util", ":options_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -814,7 +815,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,6 +841,7 @@ cc_library( ], deps = [ ":container_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -850,7 +851,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,6 +893,7 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", + ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -904,7 +905,6 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 009eb3f9e..cc5e50dfc 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,6 +564,7 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -571,7 +572,6 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", - "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - "//mediapipe/gpu:gl_context_options_cc_proto", + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", + ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,6 +930,7 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", 
"//mediapipe/framework:calculator_framework", @@ -937,7 +938,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", - "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) From 8c013647c87cc5784cd545df5f92afd33c6fe941 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 04:47:09 -0800 Subject: [PATCH 184/346] Internal change PiperOrigin-RevId: 496629682 --- mediapipe/calculators/core/BUILD | 14 ++--- mediapipe/calculators/image/BUILD | 4 +- mediapipe/calculators/util/BUILD | 2 +- mediapipe/framework/BUILD | 70 ++++++++++++------------ mediapipe/framework/formats/BUILD | 24 ++++---- mediapipe/framework/formats/motion/BUILD | 2 +- mediapipe/framework/stream_handler/BUILD | 14 ++--- mediapipe/framework/tool/BUILD | 18 +++--- mediapipe/gpu/BUILD | 14 ++--- 9 files changed, 81 insertions(+), 81 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index b3378a74e..2c143a609 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - ":packet_thinner_calculator_cc_proto", + "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - ":packet_thinner_calculator_cc_proto", + "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - ":packet_resampler_calculator_cc_proto", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - ":packet_resampler_calculator_cc_proto", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ - ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - ":pass_through_calculator", + 
"//mediapipe/calculators/core:counting_source_calculator", + "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - ":pass_through_calculator", + "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 9aae8cfbc..530dd3d4a 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ - ":scale_image_calculator_cc_proto", ":scale_image_utils", + "//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", - ":image_transformation_calculator", ":warp_affine_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index a679a80fd..1529ead8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ - ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 082ea9994..0dd694760 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -226,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ - ":calculator_cc_proto", ":graph_service", - ":mediapipe_options_cc_proto", - ":packet_generator_cc_proto", ":packet_type", ":port", - ":status_handler_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - ":calculator_cc_proto", - ":packet_generator_cc_proto", - ":status_handler_cc_proto", - ":thread_pool_executor_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -391,7 +391,6 @@ cc_library( visibility = [":mediapipe_internal"], 
deps = [ ":calculator_base", - ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -408,9 +407,10 @@ cc_library( ":packet_set", ":packet_type", ":port", - ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,7 +466,6 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ - ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -476,6 +475,7 @@ cc_library( ":packet", ":packet_set", ":port", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +583,7 @@ cc_library( hdrs = ["executor.h"], visibility = ["//visibility:public"], deps = [ - ":mediapipe_options_cc_proto", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", - ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", - ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", - ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", + "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", - ":packet_factory_cc_proto", ":packet_generator", - ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", + "//mediapipe/framework:packet_factory_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ - ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ - ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1097,7 +1097,7 @@ cc_library( 
visibility = ["//visibility:public"], deps = [ ":executor", - ":thread_pool_executor_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1162,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", - ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", - ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", - ":status_handler_cc_proto", - ":stream_handler_cc_proto", ":subgraph", - ":thread_pool_executor_cc_proto", ":timestamp", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", + "//mediapipe/framework:stream_handler_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1202,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ - ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1233,7 +1233,6 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", - ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1243,6 +1242,7 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1256,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ - ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - ":packet_generator_cc_proto", - ":status_handler_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1368,7 +1368,6 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ - ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1377,6 +1376,7 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1403,7 +1403,6 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", - ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1411,12 +1410,13 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", - ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", 
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1481,12 +1481,12 @@ cc_test( ], visibility = ["//visibility:public"], deps = [ - ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", - ":packet_generator_cc_proto", ":packet_type", + "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f5a043f10..fdd9b8909 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = [":location_data_proto"], + deps = ["//mediapipe/framework/formats:location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = [":detection_cc_proto"], + deps = ["//mediapipe/framework/formats:detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ - ":matrix_data_cc_proto", "//mediapipe/framework:port", + "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - ":image_format_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ - ":image_format_cc_proto", ":image_frame", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - ":location_data_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - ":image_format_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ - ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - ":image_format_cc_proto", - ":image_frame", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", 
"@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = ["image_multi_pool.h"], deps = [ ":image", - ":image_frame_pool", + "//mediapipe/framework/formats:image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - ":image_format_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index c9bb8b4ff..f1bbc0289 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = [ - ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", + "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 68a9af52d..01ef6ee86 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ - ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", - ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ - ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", + "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ - ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", + "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,7 +243,6 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", - 
":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ -252,6 +251,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", - ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", - ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 193343a90..89cb802da 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,7 +299,6 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", - ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -314,6 +313,7 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", + "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ - ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ - ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,7 +506,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -516,6 +515,7 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + 
"//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 +661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ - ":simulation_clock", "//mediapipe/framework:thread_pool_executor", + "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", - ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,7 +805,6 @@ cc_library( deps = [ ":container_util", ":options_util", - ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -815,6 +814,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,7 +841,6 @@ cc_library( ], deps = [ ":container_util", - ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -851,6 +850,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,7 +893,6 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", - ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -905,6 +904,7 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index cc5e50dfc..009eb3f9e 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,7 +564,6 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ - ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -572,6 +571,7 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", + "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - ":gl_context_options_cc_proto", + "//mediapipe/gpu:gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", - ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ - ":scale_mode_proto", "//mediapipe/framework:calculator_proto", + 
"//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,7 +930,6 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", - ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", "//mediapipe/framework:calculator_framework", @@ -938,6 +937,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", + "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", - ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ - ":scale_mode_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:scale_mode_proto", ], ) From e405c2b67d68e6c99fbd7bbf4731ce4a387201f7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 10:59:23 -0800 Subject: [PATCH 185/346] Internal change PiperOrigin-RevId: 496702117 --- .../calculators/image/affine_transformation_runner_gl.cc | 6 +++--- .../tensor/image_to_tensor_converter_gl_texture.cc | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc index c38fc8e07..361dfc902 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner constexpr GLchar kVertShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner )"; constexpr GLchar kFragShader[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) in vec2 sample_coordinate; uniform sampler2D input_texture; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 5efd34041..165df8970 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter { constexpr GLchar kExtractSubRectVertexShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter { )"; constexpr GLchar kExtractSubRectFragBody[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) // Provided by kExtractSubRectVertexShader. 
in vec2 sample_coordinate; From e997a19289d85071775751b453aa2e1b982f3891 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:22:32 +0530 Subject: [PATCH 186/346] Added common utils and string helpers --- mediapipe/tasks/ios/common/utils/BUILD | 41 ++++++ .../ios/common/utils/sources/MPPCommonUtils.h | 78 ++++++++++ .../common/utils/sources/MPPCommonUtils.mm | 137 ++++++++++++++++++ .../common/utils/sources/NSString+Helpers.h | 28 ++++ .../common/utils/sources/NSString+Helpers.mm | 27 ++++ 5 files changed, 311 insertions(+) create mode 100644 mediapipe/tasks/ios/common/utils/BUILD create mode 100644 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h create mode 100644 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm create mode 100644 mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h create mode 100644 mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD new file mode 100644 index 000000000..f2ffda39e --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -0,0 +1,41 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCommonUtils", + srcs = ["sources/MPPCommonUtils.mm"], + hdrs = ["sources/MPPCommonUtils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/ios/common:MPPCommon", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +objc_library( + name = "NSStringHelpers", + srcs = ["sources/NSString+Helpers.mm"], + hdrs = ["sources/NSString+Helpers.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], +) + diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h new file mode 100644 index 000000000..8a90856c7 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -0,0 +1,78 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#include "mediapipe/tasks/cc/common.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Error domain of MediaPipe task-related errors. */ +extern NSString *const MPPTasksErrorDomain; + +/** Helper utility for all tasks, encapsulating common functionality.
*/ +@interface MPPCommonUtils : NSObject + +/** + * Creates and saves an NSError in the MediaPipe task library domain, with the given code and + * description. + * + * @param code Error code. + * @param description Error description. + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + */ ++ (void)createCustomError:(NSError **)error + withCode:(NSUInteger)code + description:(NSString *)description; + +/** + * Creates and saves an NSError with the given domain, code and description. + * + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + * @param domain Error domain. + * @param code Error code. + * @param description Error description. + */ ++ (void)createCustomError:(NSError **)error + withDomain:(NSString *)domain + code:(NSUInteger)code + description:(NSString *)description; + +/** + * Converts an absl status to an NSError. + * + * @param status absl status. + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + */ ++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error; + +/** + * Allocates a block of memory with the specified size and returns a pointer to it. If memory + * cannot be allocated because of an invalid memSize, it saves an error. In all other failure + * cases, it terminates program execution. + * + * @param memSize Size of the memory to be allocated. + * @param error Pointer to the memory location where errors, if any, should be saved. If `nil`, no + * error will be saved. + * + * @return Pointer to the allocated block of memory on successful allocation. nil in case an + * error is encountered because of an invalid memSize. If the failure is due to any other reason, + * the method terminates program execution. + */ ++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error; +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm new file mode 100644 index 000000000..574f2ef9a --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -0,0 +1,137 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" + +#include <string> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl + +#include "mediapipe/tasks/cc/common.h" + +/** Error domain of MediaPipe task library errors.
*/ +NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; + +@implementation MPPCommonUtils + ++ (void)createCustomError:(NSError **)error + withCode:(NSUInteger)code + description:(NSString *)description { + [MPPCommonUtils createCustomError:error + withDomain:MPPTasksErrorDomain + code:code + description:description]; +} + ++ (void)createCustomError:(NSError **)error + withDomain:(NSString *)domain + code:(NSUInteger)code + description:(NSString *)description { + if (error) { + *error = [NSError errorWithDomain:domain + code:code + userInfo:@{NSLocalizedDescriptionKey : description}]; + } +} + ++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error { + if (!memSize) { + [MPPCommonUtils createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:@"memSize cannot be zero."]; + return NULL; + } + + void *allocedMemory = malloc(memSize); + if (!allocedMemory) { + exit(-1); + } + + return allocedMemory; +} + ++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError *_Nullable *)error { + if (status.ok()) { + return YES; + } + // The payload of an absl::Status created by the MediaPipe task library stores an appropriate + // value of the enum MediaPipeTasksStatus. The integer value corresponding to the + // MediaPipeTasksStatus enum stored in the payload is extracted here to later map to the + // appropriate error code to be returned. In cases where the enum is not stored in the payload + // (the payload is NULL or the payload string cannot be converted to an integer), we set the + // error code value to be 1 (MPPTasksErrorCodeError of MPPTasksErrorCode, used in the iOS library + // to signify any errors not falling into other categories.) Since the payload is of type + // absl::Cord, which can be type cast into an absl::optional<std::string>, we use the std::stoi + // function to convert it into an integer code if possible. + NSUInteger genericErrorCode = MPPTasksErrorCodeError; + NSUInteger errorCode; + try { + // Try converting the payload to an integer if the payload is not empty. Otherwise convert a + // string signifying the generic error code MPPTasksErrorCodeError to an integer. + errorCode = + (NSUInteger)std::stoi(static_cast<absl::optional<std::string>>( + status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload)) + .value_or(std::to_string(genericErrorCode))); + } catch (std::invalid_argument &e) { + // If a non-empty payload string cannot be converted to an integer, set the error code to + // 1 (kError). + errorCode = MPPTasksErrorCodeError; + } + + // If errorCode is outside the range of possible enum values or is + // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign an + // appropriate MPPTasksErrorCode in default cases. Note: + // The mapping to absl::Status::code() is done to generate a more specific error code than + // MPPTasksErrorCodeError in cases when the payload can't be mapped to an + // MPPTasksErrorCode. This can happen when an absl::Status returned by the TFLite library is in turn returned + // without modification by MediaPipe C++ library methods. + if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { + switch (status.code()) { + case absl::StatusCode::kInternal: + errorCode = MPPTasksErrorCodeError; + break; + case absl::StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case absl::StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeError; + break; + default: + errorCode = MPPTasksErrorCodeError; + break; + } + } + + // Creates the NSError with the appropriate error + // MPPTasksErrorCode and message.
MPPTasksErrorCode has a one-to-one + // mapping with MediaPipeTasksStatus starting from the value 1 (MPPTasksErrorCodeError) + // and hence will be correctly initialized if directly cast from the integer code derived from + // MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of + // MediaPipeTasksStatus. + // + // Stores a string including the absl status code and message (if non-empty) as the + // error message. See + // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514 + // for an explanation. absl::Status::message() can also be used but is not always + // guaranteed to be non-empty. + NSString *description = [NSString + stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() + encoding:NSUTF8StringEncoding]; + [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; + return NO; +} + +@end diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h new file mode 100644 index 000000000..aac7485da --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#include <string> + +NS_ASSUME_NONNULL_BEGIN + +@interface NSString (Helpers) + +@property(readonly) std::string cppString; + ++ (NSString *)stringWithCppString:(std::string)text; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm new file mode 100644 index 000000000..183ed4365 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm @@ -0,0 +1,27 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
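+ +// A minimal usage sketch of this category (illustrative variable names, assuming the header +// below is imported): +// std::string cppGreeting = @"hello".cppString; +// NSString *objcGreeting = [NSString stringWithCppString:cppGreeting];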
+ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +@implementation NSString (Helpers) + +- (std::string)cppString { + return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); +} + ++ (NSString *)stringWithCppString:(std::string)text { + return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]]; +} + +@end From 03bfbca53940d71b68c5c6e5c6a697abbf9fe5fe Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:22:44 +0530 Subject: [PATCH 187/346] Added classifier options --- .../tasks/ios/components/processors/BUILD | 24 +++++++++++ .../processors/sources/MPPClassifierOptions.h | 42 +++++++++++++++++++ .../processors/sources/MPPClassifierOptions.m | 40 ++++++++++++++++++ 3 files changed, 106 insertions(+) create mode 100644 mediapipe/tasks/ios/components/processors/BUILD create mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h create mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD new file mode 100644 index 000000000..6d1cfdf59 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/BUILD @@ -0,0 +1,24 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPClassifierOptions", + srcs = ["sources/MPPClassifierOptions.m"], + hdrs = ["sources/MPPClassifierOptions.h"], +) + diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h new file mode 100644 index 000000000..8c4981642 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -0,0 +1,42 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds settings for any single iOS MediaPipe classification task. + */ +NS_SWIFT_NAME(ClassifierOptions) +@interface MPPClassifierOptions : NSObject <NSCopying> + +/** If set, all classes in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray<NSString *> *labelDenyList; + +/** If set, all classes not in this list will be filtered out from the results.
*/ +@property(nonatomic, copy) NSArray<NSString *> *labelAllowList; + +/** The locale to use for display names. */ +@property(nonatomic, copy) NSString *displayNamesLocale; + +/** Only results with a score greater than this threshold are returned. */ +@property(nonatomic) float scoreThreshold; + +/** Limit on the number of classes that can be returned in results. */ +@property(nonatomic) NSInteger maxResults; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m new file mode 100644 index 000000000..52dce23e4 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -0,0 +1,40 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" + +@implementation MPPClassifierOptions + +- (instancetype)init { + self = [super init]; + if (self) { + self.maxResults = -1; + self.scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPClassifierOptions *classifierOptions = [[MPPClassifierOptions alloc] init]; + + classifierOptions.scoreThreshold = self.scoreThreshold; + classifierOptions.maxResults = self.maxResults; + classifierOptions.labelDenyList = self.labelDenyList; + classifierOptions.labelAllowList = self.labelAllowList; + classifierOptions.displayNamesLocale = self.displayNamesLocale; + + return classifierOptions; +} + +@end From c56ef735d7f8bf4fb7482c2b7dd01a61c3d0ffc4 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:22:57 +0530 Subject: [PATCH 188/346] Added classifier options helpers --- .../ios/components/processors/utils/BUILD | 29 ++++++++++++++ .../sources/MPPClassifierOptions+Helpers.h | 25 ++++++++++++ .../sources/MPPClassifierOptions+Helpers.mm | 38 +++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 mediapipe/tasks/ios/components/processors/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD new file mode 100644 index 000000000..820c6bb56 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/BUILD @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPClassifierOptionsHelpers", + srcs = ["sources/MPPClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ] +) + diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h new file mode 100644 index 000000000..6644a6255 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h @@ -0,0 +1,25 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifierOptions (Helpers) +- (void)copyToProto: + (mediapipe::tasks::components::processors::proto::ClassifierOptions *)classifierOptionsProto; +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm new file mode 100644 index 000000000..25e657599 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -0,0 +1,38 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
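+ +// A minimal usage sketch of this helper category (illustrative names, assuming the imports +// below and a C++ proto object to populate): +// MPPClassifierOptions *options = [[MPPClassifierOptions alloc] init]; +// options.maxResults = 3; +// options.scoreThreshold = 0.5f; +// [options copyToProto:&classifierOptionsProto];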
+ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" + +namespace { +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} + +@implementation MPPClassifierOptions (Helpers) +- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto { + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + for (NSString *category in self.labelAllowList) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.labelDenyList) { + classifierOptionsProto->add_category_denylist(category.cppString); + } +} + +@end From 6d02108bf5244f190dc07e035913694864138467 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:23:29 +0530 Subject: [PATCH 189/346] Added task info --- .../tasks/ios/core/sources/MPPTaskInfo.h | 69 +++++++++ .../tasks/ios/core/sources/MPPTaskInfo.mm | 136 ++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskInfo.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h new file mode 100644 index 000000000..fca660fae --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -0,0 +1,69 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#include "mediapipe/framework/calculator.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" + + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds all needed information to initialize a MediaPipe task. + */ +@interface MPPTaskInfo : NSObject <NSCopying> + +@property(nonatomic, copy, nonnull) NSString *taskGraphName; + +/** + * Task-specific options that are derived from MPPTaskOptions and conform to + * MPPTaskOptionsProtocol. + */ +@property(nonatomic, copy) id<MPPTaskOptionsProtocol> taskOptions; + +/** + * List of task graph input stream info strings in the form TAG:name. + */ +@property(nonatomic, copy) NSArray<NSString *> *inputStreams; + +/** + * List of task graph output stream info strings in the form TAG:name. + */ +@property(nonatomic, copy) NSArray<NSString *> *outputStreams; + +/** + * Whether the task requires a flow limiter. + */ +@property(nonatomic) BOOL enableFlowLimiting; + ++ (instancetype)new NS_UNAVAILABLE; + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray<NSString *> *)inputStreams + outputStreams:(NSArray<NSString *> *)outputStreams + taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error; + +/** + * Creates a MediaPipe task graph config proto from the MPPTaskInfo instance.
+ */ +- (mediapipe::CalculatorGraphConfig)generateGraphConfig; + +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm new file mode 100644 index 000000000..7d2fd6f28 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -0,0 +1,136 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_options.pb.h" + +namespace { +using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; +using Node = ::mediapipe::CalculatorGraphConfig::Node; +using ::mediapipe::InputStreamInfo; +using ::mediapipe::CalculatorOptions; +using ::mediapipe::FlowLimiterCalculatorOptions; +} // namespace + +@implementation MPPTaskInfo + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray<NSString *> *)inputStreams + outputStreams:(NSArray<NSString *> *)outputStreams + taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error { + if (!taskGraphName || !inputStreams.count || !outputStreams.count) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Task graph's name, input streams, and output streams should be non-empty."]; + return nil; + } + + self = [super init]; + + if (self) { + _taskGraphName = taskGraphName; + _inputStreams = inputStreams; + _outputStreams = outputStreams; + _taskOptions = taskOptions; + _enableFlowLimiting = enableFlowLimiting; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] init]; + + taskInfo.taskGraphName = self.taskGraphName; + taskInfo.inputStreams = self.inputStreams; + taskInfo.outputStreams = self.outputStreams; + taskInfo.taskOptions = self.taskOptions; + taskInfo.enableFlowLimiting = self.enableFlowLimiting; + + return taskInfo; +} + +- (CalculatorGraphConfig)generateGraphConfig { + CalculatorGraphConfig graph_config; + + Node *task_subgraph_node = graph_config.add_node(); + task_subgraph_node->set_calculator(self.taskGraphName.cppString); + [self.taskOptions copyToProto:task_subgraph_node->mutable_options()]; + + for (NSString *outputStream in self.outputStreams) { + auto cpp_output_stream = std::string(outputStream.cppString); + task_subgraph_node->add_output_stream(cpp_output_stream); + graph_config.add_output_stream(cpp_output_stream); + } + + if (self.enableFlowLimiting) { + Node *flow_limit_calculator_node = graph_config.add_node(); + + flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo
*input_stream_info = flow_limit_calculator_node->add_input_stream_info(); + input_stream_info->set_tag_index("FINISHED"); + input_stream_info->set_back_edge(true); + + FlowLimiterCalculatorOptions *flow_limit_calculator_options = + flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions::ext); + flow_limit_calculator_options->set_max_in_flight(1); + flow_limit_calculator_options->set_max_in_queue(1); + + for (NSString *inputStream in self.inputStreams) { + graph_config.add_input_stream(inputStream.cppString); + + NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; + flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + + NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; + task_subgraph_node->add_input_stream(taskInputStream.cppString); + + NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; + flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + } + + NSString *firstOutputStream = self.outputStreams[0]; + auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; + flow_limit_calculator_node->add_input_stream(finished_output_stream); + } else { + for (NSString *inputStream in self.inputStreams) { + auto cpp_input_stream = inputStream.cppString; + task_subgraph_node->add_input_stream(cpp_input_stream); + graph_config.add_input_stream(cpp_input_stream); + } + } + + return graph_config; +} + ++ (NSString *)stripTagIndex:(NSString *)tagIndexName { + return [tagIndexName componentsSeparatedByString:@":"][1]; +} + ++ (NSString *)addStreamNamePrefix:(NSString *)tagIndexName { + NSArray<NSString *> *splits = [tagIndexName componentsSeparatedByString:@":"]; + return [NSString stringWithFormat:@"%@:throttled_%@", splits[0], splits[1]]; +} + +@end From 64cf5e9b4e7208b3b974038d9e160e1509be1945 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:23:41 +0530 Subject: [PATCH 190/346] Added iOS task options protocol --- .../ios/core/sources/MPPTaskOptionsProtocol.h | 32 +++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h new file mode 100644 index 000000000..c6f115451 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h @@ -0,0 +1,32 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#include "mediapipe/framework/calculator_options.pb.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Any MediaPipe task options should conform to this protocol. + */ +@protocol MPPTaskOptionsProtocol + +/** + * Copies the iOS MediaPipe task options to a mediapipe::CalculatorOptions proto object.
+ */ +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END From e9fc3713f0bd0ab0fceb9ea07e78373fc8c50efd Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:23:51 +0530 Subject: [PATCH 191/346] Added iOS task runner --- .../tasks/ios/core/sources/MPPTaskRunner.h | 47 ++++++++++++++++ .../tasks/ios/core/sources/MPPTaskRunner.mm | 56 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskRunner.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h new file mode 100644 index 000000000..64e34b82e --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -0,0 +1,47 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" + + +NS_ASSUME_NONNULL_BEGIN + +/** + * This class is used to create and call appropriate methods on the C++ task runner. + */ +@interface MPPTaskRunner : NSObject +/** + * Initializes a new `MPPTaskRunner` with the MediaPipe task graph config proto. + * + * @param graphConfig A MediaPipe task graph config proto. + * + * @return An instance of `MPPTaskRunner` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error; + +- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process: + (const mediapipe::tasks::core::PacketMap &)packetMap; + +- (void)close; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm new file mode 100644 index 000000000..404f6c582 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -0,0 +1,56 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
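+ +// A minimal lifecycle sketch for this wrapper (illustrative names, assuming a valid graph config +// and input packet map): +// MPPTaskRunner *runner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:graphConfig +// error:&error]; +// auto statusOrOutputPacketMap = [runner process:inputPacketMap]; +// [runner close];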
+ +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; +} // namespace + +@interface MPPTaskRunner () { + // C++ task runner. + std::unique_ptr<TaskRunnerCpp> _cppTaskRunner; +} +@end + +@implementation MPPTaskRunner + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + self = [super init]; + if (self) { + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + + if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { + return nil; + } + + _cppTaskRunner = std::move(taskRunnerResult.value()); + } + return self; +} + +- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap { + return _cppTaskRunner->Process(packetMap); +} + +- (void)close { + _cppTaskRunner->Close(); +} + +@end From 4fedea60a93adb6ac9db50212b7f06f29758576e Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:24:02 +0530 Subject: [PATCH 192/346] Added text packet creator --- .../ios/core/sources/MPPTextPacketCreator.h | 26 +++++++++++++ .../ios/core/sources/MPPTextPacketCreator.mm | 29 +++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h new file mode 100644 index 000000000..03f946dd0 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h @@ -0,0 +1,26 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> + +#include "mediapipe/framework/packet.h" + +/* This class helps create MediaPipe text packets from iOS native + * types such as NSString. + */ +@interface MPPTextPacketCreator : NSObject + ++ (mediapipe::Packet)createWithText:(NSString *)text; + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm new file mode 100644 index 000000000..ca86e7a0b --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +namespace { +using ::mediapipe::MakePacket; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPTextPacketCreator + ++ (Packet)createWithText:(NSString *)text { + return MakePacket<std::string>(text.cppString); +} + +@end From ff901a80a5398276b04e03e561c9acf0892d3aaa Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:24:11 +0530 Subject: [PATCH 193/346] Added targets in core --- mediapipe/tasks/ios/core/BUILD | 53 ++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 7b648945e..adc37d901 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -36,3 +36,56 @@ objc_library( srcs = ["sources/MPPTaskResult.m"], hdrs = ["sources/MPPTaskResult.h"], ) + +objc_library( + name = "MPPTaskOptionsProtocol", + hdrs = ["sources/MPPTaskOptionsProtocol.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + ], +) + +objc_library( + name = "MPPTaskInfo", + srcs = ["sources/MPPTaskInfo.mm"], + hdrs = ["sources/MPPTaskInfo.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + ":MPPTaskOptions", + ":MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/common:MPPCommon", + ], +) + +objc_library( + name = "MPPTextPacketCreator", + srcs = ["sources/MPPTextPacketCreator.mm"], + hdrs = ["sources/MPPTextPacketCreator.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPTaskRunner", + srcs = ["sources/MPPTaskRunner.mm"], + hdrs = ["sources/MPPTaskRunner.h"], + deps = [ + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) From ce0bc2b9acb9c11d0e54aabd8cb9430aedfc0c9b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 13:51:12 -0800 Subject: [PATCH 194/346] Internal change PiperOrigin-RevId: 496742964 --- .github/bot_config.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 74a60e4b9..8ad724168 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,5 +15,4 @@ # A list of assignees assignees: - - kuaashish - - ayushgdev + - sureshdagooglecom From a7b52d2c5281e82c208932ff2bedcf85356f868f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 14:34:40 -0800 Subject: [PATCH 195/346] Internal changes PiperOrigin-RevId: 496754449 --- mediapipe/model_maker/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 9b3c9f906..d7e4a950f 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,5 +1,5 @@ absl-py -mediapipe==0.9.1 +mediapipe==0.9.0.1 numpy opencv-python tensorflow>=2.10 From d2f738793c8ec5ad9b66aeec78fac74a76b37100
Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 20 Dec 2022 15:15:24 -0800 Subject: [PATCH 196/346] Use uppercase options name for "delegate" PiperOrigin-RevId: 496764089 --- .../tasks/web/components/processors/base_options.test.ts | 6 +++--- mediapipe/tasks/web/components/processors/base_options.ts | 2 +- mediapipe/tasks/web/core/task_runner_options.d.ts | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts index 46c2277e9..6d58be68f 100644 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -86,7 +86,7 @@ describe('convertBaseOptionsToProto()', () => { it('can enable CPU delegate', async () => { const baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'cpu', + delegate: 'CPU', }); expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); }); @@ -94,7 +94,7 @@ describe('convertBaseOptionsToProto()', () => { it('can enable GPU delegate', async () => { const baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'gpu', + delegate: 'GPU', }); expect(baseOptionsProto.toObject()).toEqual({ ...mockBytesResult, @@ -117,7 +117,7 @@ describe('convertBaseOptionsToProto()', () => { it('can reset delegate', async () => { let baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'gpu', + delegate: 'GPU', }); // Clear backend baseOptionsProto = diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index 16d562262..97b62b784 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -71,7 +71,7 @@ async function configureExternalFile( /** Configues the `acceleration` option. */ function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { const acceleration = proto.getAcceleration() ?? new Acceleration(); - if (options.delegate === 'gpu') { + if (options.delegate === 'GPU') { acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); } else { acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); diff --git a/mediapipe/tasks/web/core/task_runner_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts index aa0b4a028..5f23cd4bf 100644 --- a/mediapipe/tasks/web/core/task_runner_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -31,7 +31,7 @@ export declare interface BaseOptions { modelAssetBuffer?: Uint8Array|undefined; /** Overrides the default backend to use for the provided model. */ - delegate?: 'cpu'|'gpu'|undefined; + delegate?: 'CPU'|'GPU'|undefined; } /** Options to configure MediaPipe Tasks in general. 
*/ From 64406a9bf27cd324e6856dbeb0f8b9c69d496ac7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 16:39:52 -0800 Subject: [PATCH 197/346] Internal change PiperOrigin-RevId: 496781536 --- mediapipe/calculators/core/BUILD | 14 ++--- mediapipe/calculators/image/BUILD | 4 +- mediapipe/calculators/util/BUILD | 2 +- mediapipe/framework/BUILD | 70 ++++++++++++------------ mediapipe/framework/formats/BUILD | 24 ++++---- mediapipe/framework/formats/motion/BUILD | 2 +- mediapipe/framework/stream_handler/BUILD | 14 ++--- mediapipe/framework/tool/BUILD | 18 +++--- mediapipe/gpu/BUILD | 14 ++--- 9 files changed, 81 insertions(+), 81 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 2c143a609..b3378a74e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ + ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - "//mediapipe/calculators/core:counting_source_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 530dd3d4a..9aae8cfbc 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ + ":scale_image_calculator_cc_proto", ":scale_image_utils", - 
"//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", + ":image_transformation_calculator", ":warp_affine_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 1529ead8a..a679a80fd 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ + ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", - "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 0dd694760..082ea9994 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -226,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ + ":calculator_cc_proto", ":graph_service", + ":mediapipe_options_cc_proto", + ":packet_generator_cc_proto", ":packet_type", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":calculator_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -391,6 +391,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -407,10 +408,9 @@ cc_library( ":packet_set", ":packet_type", ":port", + ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,6 +466,7 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ + ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -475,7 +476,6 @@ cc_library( ":packet", ":packet_set", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +583,7 @@ cc_library( hdrs 
= ["executor.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:mediapipe_options_cc_proto", + ":mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", + ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", + ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", + ":packet_factory_cc_proto", ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", - "//mediapipe/framework:packet_factory_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ + ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1097,7 +1097,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":executor", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1162,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", ":subgraph", + ":thread_pool_executor_cc_proto", ":timestamp", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", - 
"//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1202,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1233,6 +1233,7 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1242,7 +1243,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1256,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1368,6 +1368,7 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1376,7 +1377,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1403,6 +1403,7 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", + ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1410,13 +1411,12 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", + ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1481,12 +1481,12 @@ cc_test( ], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/formats/BUILD 
b/mediapipe/framework/formats/BUILD index fdd9b8909..f5a043f10 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = ["//mediapipe/framework/formats:location_data_proto"], + deps = [":location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = ["//mediapipe/framework/formats:detection_cc_proto"], + deps = [":detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ + ":matrix_data_cc_proto", "//mediapipe/framework:port", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ + ":image_format_cc_proto", ":image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", + ":location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ + ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:image_frame", + ":image_format_cc_proto", + ":image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = ["image_multi_pool.h"], deps = [ ":image", - "//mediapipe/framework/formats:image_frame_pool", + ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index f1bbc0289..c9bb8b4ff 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = 
[ + ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 01ef6ee86..68a9af52d 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ + ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ + ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,6 +243,7 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ -251,7 +252,6 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", 
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 89cb802da..193343a90 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,6 +299,7 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", + ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -313,7 +314,6 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", - "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ + ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,6 +506,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -515,7 +516,6 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 +661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ + ":simulation_clock", "//mediapipe/framework:thread_pool_executor", - "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,6 +805,7 @@ cc_library( deps = [ ":container_util", ":options_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ 
-814,7 +815,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,6 +841,7 @@ cc_library( ], deps = [ ":container_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -850,7 +851,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,6 +893,7 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", + ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -904,7 +905,6 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 009eb3f9e..cc5e50dfc 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,6 +564,7 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -571,7 +572,6 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", - "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - "//mediapipe/gpu:gl_context_options_cc_proto", + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", + ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,6 +930,7 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", "//mediapipe/framework:calculator_framework", @@ -937,7 +938,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", - "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = 
"gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) From 151e447614741f02185c94f4412a3ab665a16c17 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 20 Dec 2022 17:50:21 -0800 Subject: [PATCH 198/346] Internal changes PiperOrigin-RevId: 496793199 --- mediapipe/calculators/core/sequence_shift_calculator.cc | 6 ++++++ mediapipe/calculators/core/sequence_shift_calculator.proto | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index 66dbdef2e..026048b79 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node { // The number of packets or timestamps we need to store to output packet[i] at // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). int cache_size_; + bool emit_empty_packets_before_first_packet_ = false; }; MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { packet_offset_ = kOffset(cc).GetOr( cc->Options().packet_offset()); + emit_empty_packets_before_first_packet_ = + cc->Options() + .emit_empty_packets_before_first_packet(); cache_size_ = abs(packet_offset_); // An offset of zero is a no-op, but someone might still request it. if (packet_offset_ == 0) { @@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { // Ready to output oldest packet with current timestamp. kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); packet_cache_.pop_front(); + } else if (emit_empty_packets_before_first_packet_) { + LOG(FATAL) << "Not supported yet"; } // Store current packet for later output. packet_cache_.push_back(kIn(cc).packet()); diff --git a/mediapipe/calculators/core/sequence_shift_calculator.proto b/mediapipe/calculators/core/sequence_shift_calculator.proto index 15b111d71..36b0bb959 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.proto +++ b/mediapipe/calculators/core/sequence_shift_calculator.proto @@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions { optional SequenceShiftCalculatorOptions ext = 107633927; } optional int32 packet_offset = 1 [default = -1]; + + // Emits empty packets before the first delayed packet is emitted. Takes + // effect only when packet offset is set to positive. + optional bool emit_empty_packets_before_first_packet = 2 [default = false]; } From 5c0f548f5f5b31d94b749456cdac306b5330dfa3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 20 Dec 2022 20:51:23 -0800 Subject: [PATCH 199/346] Switches to tf.keras.optimizers.experimental.AdamW instead of the legacy AdamW. 
PiperOrigin-RevId: 496821354 --- .../text/text_classifier/text_classifier.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index c285702d2..c4d3fdbe2 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -417,8 +417,22 @@ class _BertClassifier(TextClassifier): total_steps = self._hparams.steps_per_epoch * self._hparams.epochs warmup_steps = int(total_steps * 0.1) initial_lr = self._hparams.learning_rate - self._optimizer = optimization.create_optimizer(initial_lr, total_steps, - warmup_steps) + # Implements linear decay of the learning rate. + lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=initial_lr, + decay_steps=total_steps, + end_learning_rate=0.0, + power=1.0) + if warmup_steps: + lr_schedule = optimization.WarmUp( + initial_learning_rate=initial_lr, + decay_schedule_fn=lr_schedule, + warmup_steps=warmup_steps) + + self._optimizer = tf.keras.optimizers.experimental.AdamW( + lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0) + self._optimizer.exclude_from_weight_decay( + var_names=["LayerNorm", "layer_norm", "bias"]) def _save_vocab(self, vocab_filepath: str): tf.io.gfile.copy( From 1341720d6db044d2771eabe5d5574d67bb04a4f6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 21 Dec 2022 00:52:17 -0800 Subject: [PATCH 200/346] Internal change PiperOrigin-RevId: 496854337 --- mediapipe/framework/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 082ea9994..a4c9a520d 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -369,7 +369,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":graph_service", - "//mediapipe/framework:packet", + ":packet", "@com_google_absl//absl/status", ], ) @@ -379,7 +379,7 @@ cc_test( srcs = ["graph_service_manager_test.cc"], deps = [ ":graph_service_manager", - "//mediapipe/framework:packet", + ":packet", "//mediapipe/framework/port:gtest_main", ], ) From 714a6e555b106e7fd4de1b1e83d70e1c1c8570f3 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 21 Dec 2022 08:06:20 -0800 Subject: [PATCH 201/346] Enable creating mediapipe image c++ packet directly from an Android media image object when its format is RGBA_8888. 
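The Java path added below packs the RGBA plane of an android.media.Image into
the same MediaPipe image packet that the other container types produce. For
comparison only -- this is not part of the change -- the analogous operation
on the Python surface, packing a raw RGBA buffer into a MediaPipe image,
looks roughly like this (array contents are placeholders):

    import numpy as np
    import mediapipe as mp

    # Placeholder 640x480 four-channel buffer standing in for the RGBA plane.
    rgba = np.zeros((480, 640, 4), dtype=np.uint8)
    image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba)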
PiperOrigin-RevId: 496923491 --- .../mediapipe/framework/AndroidPacketCreator.java | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 05700ba17..fc1e5484e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -15,10 +15,13 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; +import android.graphics.PixelFormat; +import android.media.Image; import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImageProperties; +import com.google.mediapipe.framework.image.MediaImageExtractor; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator { } return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); } - + if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) { + Image mediaImage = MediaImageExtractor.extract(image); + if (mediaImage.getFormat() != PixelFormat.RGBA_8888) { + throw new UnsupportedOperationException("Android media image must use RGBA_8888 config."); + } + return createImage( + mediaImage.getPlanes()[0].getBuffer(), + mediaImage.getWidth(), + mediaImage.getHeight(), + /* numChannels= */ 4); + } // Unsupported type. throw new UnsupportedOperationException( "Unsupported Image container type: " + properties.getStorageType()); From c8b8d1fe6b04ef906b7ef3956fdff266c2704228 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 21 Dec 2022 11:08:01 -0800 Subject: [PATCH 202/346] Remove scripts for building MediaPipe Python 3.7 wheels. 
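Note that the trove classifiers updated below are advisory metadata: pip only
refuses installation on unsupported interpreters via `python_requires`, which
this patch does not touch. A package wanting a hard runtime floor could pair
the classifier change with a guard along these lines (a sketch, not part of
this change):

    import sys

    # Fail fast on interpreters older than the supported floor.
    if sys.version_info < (3, 8):
        raise RuntimeError("mediapipe requires Python 3.8 or newer")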
PiperOrigin-RevId: 496962729
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index b072a850e..992430cf1 100644
--- a/setup.py
+++ b/setup.py
@@ -490,10 +490,10 @@ setuptools.setup(
         'Operating System :: MacOS :: MacOS X',
         'Operating System :: Microsoft :: Windows',
         'Operating System :: POSIX :: Linux',
-        'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
         'Programming Language :: Python :: 3 :: Only',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',

From ae28948ca150fd3a801c2ef1387b460151083204 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:49:24 +0530
Subject: [PATCH 203/346] Marked designated initializers

---
 mediapipe/tasks/ios/core/sources/MPPTaskInfo.h   | 2 +-
 mediapipe/tasks/ios/core/sources/MPPTaskRunner.h | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
index fca660fae..4c01787a8 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h
@@ -55,7 +55,7 @@ NS_ASSUME_NONNULL_BEGIN
                      outputStreams:(NSArray *)outputStreams
                        taskOptions:(id)taskOptions
                 enableFlowLimiting:(BOOL)enableFlowLimiting
-                             error:(NSError **)error;
+                             error:(NSError **)error NS_DESIGNATED_INITIALIZER;

 /**
  * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
index 64e34b82e..e07cb344d 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
@@ -24,6 +24,7 @@ NS_ASSUME_NONNULL_BEGIN
 * This class is used to create and call appropriate methods on the C++ Task Runner.
 */
 @interface MPPTaskRunner : NSObject
+
 /**
 * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto.
 *
@@ -32,7 +33,7 @@ NS_ASSUME_NONNULL_BEGIN
 * @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
 */
 - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
-                                        error:(NSError **)error;
+                                        error:(NSError **)error NS_DESIGNATED_INITIALIZER;

 - (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error;

From 481f4e960e009df7431f7281312e985507428c87 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:49:44 +0530
Subject: [PATCH 204/346] Updated comments

---
 mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
index c6f115451..44fba4c0b 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h
@@ -18,12 +18,12 @@ NS_ASSUME_NONNULL_BEGIN

 /**
- * Any mediapipe task options should confirm to this protocol.
+ * Any MediaPipe task options should conform to this protocol.
 */
 @protocol MPPTaskOptionsProtocol

 /**
- * Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto.
+ * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
 */
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;

From 2943d1668e0e0c66dad9a0e6626dc1c82e38cd3d Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:51:20 +0530
Subject: [PATCH 205/346] Updated comments

---
 mediapipe/tasks/ios/core/sources/MPPTaskRunner.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
index e07cb344d..9dfef02e1 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
@@ -23,6 +23,7 @@ NS_ASSUME_NONNULL_BEGIN
 /**
 * This class is used to create and call appropriate methods on the C++ Task Runner.
 */
+
 @interface MPPTaskRunner : NSObject

 /**

From 20f2e136c520b937d81a8241aec7a4ca869d3f70 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 00:59:22 +0530
Subject: [PATCH 206/346] Updated empty spaces

---
 .../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
index 25e657599..db7fa6bfd 100644
--- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
+++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
@@ -20,12 +20,16 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
 }

 @implementation MPPClassifierOptions (Helpers)
+
 - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
   if (self.displayNamesLocale) {
     classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
   }
+
   classifierOptionsProto->set_max_results((int)self.maxResults);
+
   classifierOptionsProto->set_score_threshold(self.scoreThreshold);
+
   for (NSString *category in self.labelAllowList) {
     classifierOptionsProto->add_category_allowlist(category.cppString);
   }

From 1491b3f5a2da5ac3415edd8cd946e0e9b639887b Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 01:00:36 +0530
Subject: [PATCH 207/346] Updated comments

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index 8a90856c7..d2e6067d5 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -24,7 +24,7 @@ extern NSString *const MPPTasksErrorDomain;
 @interface MPPCommonUtils : NSObject

 /**
- * Creates and saves an NSError in the Mediapipe task library domain, with the given code and
+ * Creates and saves an NSError in the MediaPipe task library domain, with the given code and
 * description.
 *
 * @param code Error code.
From 1de369417572ebb09e26c72d8a6fa3e5f7685795 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 01:02:07 +0530
Subject: [PATCH 208/346] Updated comments

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index d2e6067d5..1a44ee45a 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -51,9 +51,9 @@ extern NSString *const MPPTasksErrorDomain;
                description:(NSString *)description;

 /**
- * Converts an absl status to an NSError.
+ * Converts an absl::Status to an NSError.
 *
- * @param status absl status.
+ * @param status absl::Status.
 * @param error Pointer to the memory location where the created error should be saved. If `nil`,
 * no error will be saved.
 */
@@ -68,7 +68,7 @@ extern NSString *const MPPTasksErrorDomain;
 * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
 * error will be saved.
 *
- * @return Pointer to the allocated block of memory on successfull allocation. nil in case as
+ * @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
 * error is encountered because of invalid memSize. If failure is due to any other reason, method
 * terminates program execution.
 */
 + (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;

From 99c11ff9743fab7799bfd7db9e2e63755fe4a123 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 01:03:39 +0530
Subject: [PATCH 209/346] Updated comments

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index 1a44ee45a..407d87aba 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -61,7 +61,7 @@ extern NSString *const MPPTasksErrorDomain;

 /**
 * Allocates a block of memory with the specified size and returns a pointer to it. If memory
- * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
+ * cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it
 * terminates program execution.
 *
 * @param memSize size of memory to be allocated
@@ -69,7 +69,7 @@ extern NSString *const MPPTasksErrorDomain;
 * error will be saved.
 *
 * @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
- * error is encountered because of invalid memSize. If failure is due to any other reason, method
+ * error is encountered because of invalid `memSize`. If failure is due to any other reason, method
 * terminates program execution.
*/ + (void *)mallocWithSize:(size_t)memSize error:(NSError **)error; From 7ae4b7e6394b5e75315fb46b4ee9c44b6e02ecc1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:05:01 +0530 Subject: [PATCH 210/346] Updated error domain --- mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 574f2ef9a..4d4880a87 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -24,7 +24,7 @@ #include "mediapipe/tasks/cc/common.h" /** Error domain of MediaPipe task library errors. */ -NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; +NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; @implementation MPPCommonUtils @@ -68,7 +68,7 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of + // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string From 54d36dfedad1d2e84f680fb69defb13b6eae45b9 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:05:50 +0530 Subject: [PATCH 211/346] Update MPPClassifierOptions.h --- .../ios/components/processors/sources/MPPClassifierOptions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 8c4981642..b31dadb63 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -17,7 +17,7 @@ NS_ASSUME_NONNULL_BEGIN /** - * Holds settings for any single iOS Mediapipe classification task. + * Holds settings for any single iOS MediaPipe classification task. */ NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject From 673b38dfe87c35504ac81f5b29935ab6b25beaa1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:08:13 +0530 Subject: [PATCH 212/346] Updated comments --- .../processors/sources/MPPClassifierOptions.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index b31dadb63..d6b9a9582 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -22,16 +22,18 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** If set, all classes in this list will be filtered out from the results . */ +/** If set, all classes in this list will be filtered out from the results. */ @property(nonatomic, copy) NSArray *labelDenyList; -/** If set, all classes not in this list will be filtered out from the results . 
*/ +/** If set, all classes not in this list will be filtered out from the results. */ @property(nonatomic, copy) NSArray *labelAllowList; -/** Display names local for display names*/ +/** The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** Results with score threshold greater than this value are returned . */ +/** Results with score threshold greater than this value are returned. */ @property(nonatomic) float scoreThreshold; /** Limit to the number of classes that can be returned in results. */ From 66ee8d47c0d13c4f0a4f4ee91bde7bf570fbaa61 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:10:07 +0530 Subject: [PATCH 213/346] Resorted options --- .../processors/sources/MPPClassifierOptions.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index d6b9a9582..0c22ed9de 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -22,22 +22,22 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** If set, all classes in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelDenyList; - -/** If set, all classes not in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelAllowList; - /** The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English. */ @property(nonatomic, copy) NSString *displayNamesLocale; +/** Limit to the number of classes that can be returned in results. */ +@property(nonatomic) NSInteger maxResults; + /** Results with score threshold greater than this value are returned. */ @property(nonatomic) float scoreThreshold; -/** Limit to the number of classes that can be returned in results. */ -@property(nonatomic) NSInteger maxResults; +/** If set, all classes not in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray *labelAllowList; + +/** If set, all classes in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray *labelDenyList; @end From e1dfcf03cf41f0f9519206ea4fa97f255161191f Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:12:34 +0530 Subject: [PATCH 214/346] Updated comments in MPPClassifierOptions.h --- .../ios/components/processors/sources/MPPClassifierOptions.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 0c22ed9de..371472cab 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -30,7 +30,9 @@ NS_SWIFT_NAME(ClassifierOptions) /** Limit to the number of classes that can be returned in results. */ @property(nonatomic) NSInteger maxResults; -/** Results with score threshold greater than this value are returned. */ +/** Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. 
+ */ @property(nonatomic) float scoreThreshold; /** If set, all classes not in this list will be filtered out from the results. */ From c185dc9ad7ba33844fac9560f33c59fb2c9e4ad6 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:19:01 +0530 Subject: [PATCH 215/346] Renamed label to category in classifier options --- .../ios/components/processors/sources/MPPClassifierOptions.h | 4 ++-- .../ios/components/processors/sources/MPPClassifierOptions.m | 4 ++-- .../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 371472cab..0f0abe398 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -36,10 +36,10 @@ NS_SWIFT_NAME(ClassifierOptions) @property(nonatomic) float scoreThreshold; /** If set, all classes not in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelAllowList; +@property(nonatomic, copy) NSArray *categoryAllowList; /** If set, all classes in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *labelDenyList; +@property(nonatomic, copy) NSArray *categoryDenyList; @end diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index 52dce23e4..1d9191802 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -30,8 +30,8 @@ classifierOptions.scoreThreshold = self.scoreThreshold; classifierOptions.maxResults = self.maxResults; - classifierOptions.labelDenyList = self.labelDenyList; - classifierOptions.labelAllowList = self.labelAllowList; + classifierOptions.categoryDenyList = self.categoryDenyList; + classifierOptions.categoryAllowList = self.categoryAllowList; classifierOptions.displayNamesLocale = self.displayNamesLocale; return classifierOptions; diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index db7fa6bfd..3d8397efa 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -30,11 +30,11 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto classifierOptionsProto->set_score_threshold(self.scoreThreshold); - for (NSString *category in self.labelAllowList) { + for (NSString *category in self.categoryAllowList) { classifierOptionsProto->add_category_allowlist(category.cppString); } - for (NSString *category in self.labelDenyList) { + for (NSString *category in self.categoryDenyList) { classifierOptionsProto->add_category_denylist(category.cppString); } } From 20c3388ab68c11082de43aff825762686b3bc8e1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 01:59:38 +0530 Subject: [PATCH 216/346] Updated category allowlist and denylist names --- .../ios/components/processors/sources/MPPClassifierOptions.h | 4 ++-- .../ios/components/processors/sources/MPPClassifierOptions.m | 4 ++-- 
.../processors/utils/sources/MPPClassifierOptions+Helpers.mm | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 0f0abe398..e95de89e4 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -36,10 +36,10 @@ NS_SWIFT_NAME(ClassifierOptions) @property(nonatomic) float scoreThreshold; /** If set, all classes not in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *categoryAllowList; +@property(nonatomic, copy) NSArray *categoryAllowlist; /** If set, all classes in this list will be filtered out from the results. */ -@property(nonatomic, copy) NSArray *categoryDenyList; +@property(nonatomic, copy) NSArray *categoryDenylist; @end diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index 1d9191802..accb6c7dd 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -30,8 +30,8 @@ classifierOptions.scoreThreshold = self.scoreThreshold; classifierOptions.maxResults = self.maxResults; - classifierOptions.categoryDenyList = self.categoryDenyList; - classifierOptions.categoryAllowList = self.categoryAllowList; + classifierOptions.categoryDenylist = self.categoryDenylist; + classifierOptions.categoryAllowlist = self.categoryAllowlist; classifierOptions.displayNamesLocale = self.displayNamesLocale; return classifierOptions; diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index 3d8397efa..81fe57d13 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -30,11 +30,11 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto classifierOptionsProto->set_score_threshold(self.scoreThreshold); - for (NSString *category in self.categoryAllowList) { + for (NSString *category in self.categoryAllowlist) { classifierOptionsProto->add_category_allowlist(category.cppString); } - for (NSString *category in self.categoryDenyList) { + for (NSString *category in self.categoryDenylist) { classifierOptionsProto->add_category_denylist(category.cppString); } } From b4a7644428ac05af45dc737bd1002b6a8f6154cc Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 02:01:04 +0530 Subject: [PATCH 217/346] Updated comments --- .../ios/components/processors/sources/MPPClassifierOptions.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index e95de89e4..348e94e96 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -27,7 +27,10 @@ NS_SWIFT_NAME(ClassifierOptions) */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** Limit to the number of classes that can be returned in results. 
*/ +/** The maximum number of top-scored classification results to return. If < 0, + * all available results will be returned. If 0, an invalid argument error is + * returned. + */ @property(nonatomic) NSInteger maxResults; /** Score threshold to override the one provided in the model metadata (if any). From e559613b9de8d73e0d4956688561174b58e2dcb9 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 02:02:48 +0530 Subject: [PATCH 218/346] Updated comments in MPPClassifierOptions.h --- .../processors/sources/MPPClassifierOptions.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 348e94e96..7bf5744f7 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -38,10 +38,16 @@ NS_SWIFT_NAME(ClassifierOptions) */ @property(nonatomic) float scoreThreshold; -/** If set, all classes not in this list will be filtered out from the results. */ +/** The allowlist of category names. If non-empty, detection results whose + * category name is not in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryDenylist. + */ @property(nonatomic, copy) NSArray *categoryAllowlist; -/** If set, all classes in this list will be filtered out from the results. */ +/** The denylist of category names. If non-empty, detection results whose + * category name is in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryAllowlist. + */ @property(nonatomic, copy) NSArray *categoryDenylist; @end From 69b6d9d970a9eae8d7c9e085201ba888ef4ef54b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 21 Dec 2022 17:39:54 -0800 Subject: [PATCH 219/346] Internal change PiperOrigin-RevId: 497043596 --- mediapipe/web/graph_runner/graph_runner.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index a9bb979af..ef866bc91 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -1028,7 +1028,9 @@ export class GraphRunner { // Set up our TS listener to receive any packets for this stream, and // additionally reformat our Uint8Array into a Float32Array for the user. 
     this.setListener(outputStreamName, (data: Uint8Array) => {
-      const floatArray = new Float32Array(data.buffer);  // Should be very fast
+      // Should be very fast
+      const floatArray =
+          new Float32Array(data.buffer, data.byteOffset, data.length / 4);
       callbackFcn(floatArray);
     });

From e47256ae55af3921d0878cf131c32625a2500082 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 12:10:23 +0530
Subject: [PATCH 220/346] Clearing proto before assigning new values in
 MPPClassifierOptions Helpers

---
 .../processors/utils/sources/MPPClassifierOptions+Helpers.mm  | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
index 81fe57d13..efe9572e1 100644
--- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
+++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm
@@ -22,6 +22,8 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
 @implementation MPPClassifierOptions (Helpers)
 
 - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
+  classifierOptionsProto->Clear();
+
   if (self.displayNamesLocale) {
     classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
   }

From 613ed588908ac3bd39b48bf05e21c2fa52eeb9ad Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 22 Dec 2022 12:16:33 +0530
Subject: [PATCH 221/346] Inverted condition check in MPPTaskInfo

---
 .../tasks/ios/core/sources/MPPTaskInfo.mm     | 67 ++++++++++---------
 1 file changed, 34 insertions(+), 33 deletions(-)

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm
index 7d2fd6f28..be3c8cbf7 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm
@@ -24,9 +24,9 @@ namespace {
 using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
 using Node = ::mediapipe::CalculatorGraphConfig::Node;
-using ::mediapipe::InputStreamInfo;
 using ::mediapipe::CalculatorOptions;
 using ::mediapipe::FlowLimiterCalculatorOptions;
+using ::mediapipe::InputStreamInfo;
 }  // namespace
 
 @implementation MPPTaskInfo
@@ -82,45 +82,46 @@ using ::mediapipe::FlowLimiterCalculatorOptions;
     graph_config.add_output_stream(cpp_output_stream);
   }
 
-  if (self.enableFlowLimiting) {
-    Node *flow_limit_calculator_node = graph_config.add_node();
-
-    flow_limit_calculator_node->set_calculator("FlowLimiterCalculator");
-
-    InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info();
-    input_stream_info->set_tag_index("FINISHED");
-    input_stream_info->set_back_edge(true);
-
-    FlowLimiterCalculatorOptions *flow_limit_calculator_options =
-        flow_limit_calculator_node->mutable_options()->MutableExtension(
-            FlowLimiterCalculatorOptions::ext);
-    flow_limit_calculator_options->set_max_in_flight(1);
-    flow_limit_calculator_options->set_max_in_queue(1);
-
-    for (NSString *inputStream in self.inputStreams) {
-      graph_config.add_input_stream(inputStream.cppString);
-
-      NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream];
-      flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString);
-
-      NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream];
-      task_subgraph_node->add_input_stream(taskInputStream.cppString);
-
-      NSString *strippedTaskInputStream = [MPPTaskInfo
stripTagIndex:taskInputStream]; - flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); - } - - NSString *firstOutputStream = self.outputStreams[0]; - auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; - flow_limit_calculator_node->add_input_stream(finished_output_stream); - } else { + if (!self.enableFlowLimiting) { for (NSString *inputStream in self.inputStreams) { auto cpp_input_stream = inputStream.cppString; task_subgraph_node->add_input_stream(cpp_input_stream); graph_config.add_input_stream(cpp_input_stream); } + return graph_config; } + Node *flow_limit_calculator_node = graph_config.add_node(); + + flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); + input_stream_info->set_tag_index("FINISHED"); + input_stream_info->set_back_edge(true); + + FlowLimiterCalculatorOptions *flow_limit_calculator_options = + flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions::ext); + flow_limit_calculator_options->set_max_in_flight(1); + flow_limit_calculator_options->set_max_in_queue(1); + + for (NSString *inputStream in self.inputStreams) { + graph_config.add_input_stream(inputStream.cppString); + + NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; + flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + + NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; + task_subgraph_node->add_input_stream(taskInputStream.cppString); + + NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; + flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + } + + NSString *firstOutputStream = self.outputStreams[0]; + auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; + flow_limit_calculator_node->add_input_stream(finished_output_stream); + return graph_config; } From 48eeae4d9d3582661f002ddc2424e3e6c8cdd512 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 22 Dec 2022 12:16:43 +0530 Subject: [PATCH 222/346] Formatted code --- mediapipe/tasks/ios/core/sources/MPPTaskInfo.h | 1 - mediapipe/tasks/ios/core/sources/MPPTaskRunner.h | 5 +++-- mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index 4c01787a8..ae4c9eba1 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -17,7 +17,6 @@ #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" - NS_ASSUME_NONNULL_BEGIN /** diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 9dfef02e1..6561e136d 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -17,7 +17,6 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" - NS_ASSUME_NONNULL_BEGIN /** @@ -36,7 +35,9 @@ NS_ASSUME_NONNULL_BEGIN - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig error:(NSError **)error NS_DESIGNATED_INITIALIZER; -- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; +- (absl::StatusOr) + 
process:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; - (void)close; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index 404f6c582..e08d0bc1b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -45,7 +45,7 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return self; } -- (absl::StatusOr)process:(const PacketMap&)packetMap { +- (absl::StatusOr)process:(const PacketMap &)packetMap { return _cppTaskRunner->Process(packetMap); } From 967384160524a7be56da549f17abd129493ada78 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 09:47:35 -0800 Subject: [PATCH 223/346] Internal visibility update PiperOrigin-RevId: 497185157 --- mediapipe/framework/deps/BUILD | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 27bc105c8..7ff004f1e 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,9 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = [ - "//mediapipe:__subpackages__", -]) +package_group( + name = "mediapipe_internal", + packages = [ + "//mediapipe/...", + ], +) + +package(default_visibility = ["mediapipe_internal"]) bzl_library( name = "expand_template_bzl", @@ -214,6 +219,9 @@ cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], + visibility = [ + "mediapipe_internal", + ], deps = [ ":registration_token", "//mediapipe/framework/port:logging", From 5b90afda701d1ddb91a435f064507b43636ea966 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 10:19:59 -0800 Subject: [PATCH 224/346] Internal change PiperOrigin-RevId: 497191969 --- mediapipe/framework/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index a4c9a520d..83346dad1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1060,7 +1060,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_framework", - "//mediapipe/framework:test_calculators_cc_proto", + ":test_calculators_cc_proto", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:integral_types", From 36f054dfbe391b450aeb11bfc4b71e962644b72d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 10:41:03 -0800 Subject: [PATCH 225/346] Internal model maker change PiperOrigin-RevId: 497196512 --- .../model_maker/python/text/text_classifier/text_classifier.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index c4d3fdbe2..1a338e345 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer -from official.nlp import optimization def _validate(options: 
text_classifier_options.TextClassifierOptions): @@ -424,7 +423,7 @@ class _BertClassifier(TextClassifier): end_learning_rate=0.0, power=1.0) if warmup_steps: - lr_schedule = optimization.WarmUp( + lr_schedule = model_util.WarmUp( initial_learning_rate=initial_lr, decay_schedule_fn=lr_schedule, warmup_steps=warmup_steps) From 5a71b551e5ad4b85aa18bb23d994fe09b753f0f9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 15:29:18 -0800 Subject: [PATCH 226/346] Remove duplicate and non-public api for model_maker PiperOrigin-RevId: 497251246 --- mediapipe/model_maker/__init__.py | 3 +++ .../python/text/text_classifier/__init__.py | 9 ++++++++ .../python/vision/gesture_recognizer/BUILD | 2 ++ .../vision/gesture_recognizer/__init__.py | 9 ++++++++ .../gesture_recognizer_test.py | 22 ++++++++++--------- .../python/vision/image_classifier/BUILD | 2 ++ .../vision/image_classifier/__init__.py | 9 ++++++++ .../image_classifier/image_classifier_test.py | 10 +++++---- 8 files changed, 52 insertions(+), 14 deletions(-) diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 9899a145b..b37088764 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.text import text_classifier + +# Remove duplicated and non-public API +del python diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py index 618e51645..697461969 100644 --- a/mediapipe/model_maker/python/text/text_classifier/__init__.py +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions SupportedModels = model_spec.SupportedModels TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier_options.TextClassifierOptions + +# Remove duplicated and non-public API +del hyperparameters +del dataset +del model_options +del model_spec +del preprocessor # pylint: disable=undefined-variable +del text_classifier +del text_classifier_options diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 9123e36b0..cbdff7cf3 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -146,6 +146,8 @@ py_test( tags = ["notsan"], deps = [ ":gesture_recognizer_import", + ":hyperparameters", + ":model_options", "//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py index dc6923fac..a302e8d79 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py @@ -25,3 +25,12 @@ HParams = hyperparameters.HParams Dataset = dataset.Dataset HandDataPreprocessingParams = dataset.HandDataPreprocessingParams GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions + +# Remove duplicated and non-public API +del constants # pylint: disable=undefined-variable +del dataset +del gesture_recognizer +del 
gesture_recognizer_options +del hyperparameters +del metadata_writer # pylint: disable=undefined-variable +del model_options diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 08fda4fea..4fdb74225 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -23,6 +23,8 @@ import tensorflow as tf from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options from mediapipe.tasks.python.test import test_utils _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' @@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase): self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): - model_options = gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -64,11 +66,11 @@ class GestureRecognizerTest(tf.test.TestCase): tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) def test_gesture_recognizer_model_layer_widths(self, mock_dense): layer_widths = [64, 32] - model_options = gesture_recognizer.ModelOptions(layer_widths=layer_widths) + mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -87,11 +89,11 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) def test_export_gesture_recognizer_model(self): - model_options = gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -128,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase): self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( - gesture_recognizer.hyperparameters, + hyperparameters, 'HParams', autospec=True, return_value=gesture_recognizer.HParams(epochs=1)) @unittest_mock.patch.object( - gesture_recognizer.model_options, + model_options, 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) @@ -148,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase): mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): - model_options = 
gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index 29ae189e9..d7c47a359 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -121,7 +121,9 @@ py_library( srcs = ["image_classifier_test.py"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"], deps = [ + ":hyperparameters", ":image_classifier_import", + ":model_options", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py index 3d0543cd2..0f964ef66 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions ModelSpec = model_spec.ModelSpec SupportedModels = model_spec.SupportedModels ImageClassifierOptions = image_classifier_options.ImageClassifierOptions + +# Remove duplicated and non-public API +del dataset +del hyperparameters +del image_classifier +del image_classifier_options +del model_options +del model_spec +del train_image_classifier_lib # pylint: disable=undefined-variable diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 252659edc..6ca21d334 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -24,6 +24,8 @@ import numpy as np import tensorflow as tf from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters +from mediapipe.model_maker.python.vision.image_classifier import model_options from mediapipe.tasks.python.test import test_utils @@ -159,15 +161,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( - image_classifier.hyperparameters, + hyperparameters, 'HParams', autospec=True, - return_value=image_classifier.HParams(epochs=1)) + return_value=hyperparameters.HParams(epochs=1)) @unittest_mock.patch.object( - image_classifier.model_options, + model_options, 'ImageClassifierModelOptions', autospec=True, - return_value=image_classifier.ModelOptions()) + return_value=model_options.ImageClassifierModelOptions()) def test_create_hparams_and_model_options_if_none_in_image_classifier_options( self, mock_hparams, mock_model_options): options = image_classifier.ImageClassifierOptions( From 557cd050f3bf079266aaa7b88987a2cab5ab9ab3 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 22 Dec 2022 16:25:35 -0800 Subject: [PATCH 227/346] Deprecating RealTimeFlowLimiterCalculator in favor of FlowLimiterCalculator. 
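
FlowLimiterCalculator covers the same real-time throttling use case and, as the
deprecation comment below notes, adds a few configuration options that
RealTimeFlowLimiterCalculator lacks. As a rough migration sketch (assuming the
generated proto modules are importable; the stream names "frames" and
"detections" are placeholders, not from any real graph):

  from mediapipe.calculators.core import flow_limiter_calculator_pb2
  from mediapipe.framework import calculator_pb2

  # Build the replacement throttling node. The FINISHED back edge tells the
  # limiter when a frame has made it all the way through the graph.
  node = calculator_pb2.CalculatorGraphConfig.Node()
  node.calculator = 'FlowLimiterCalculator'
  node.input_stream.append('frames')               # data stream to throttle
  node.input_stream.append('FINISHED:detections')  # back edge from graph output
  info = node.input_stream_info.add()
  info.tag_index = 'FINISHED'
  info.back_edge = True
  node.output_stream.append('throttled_frames')
  opts = node.options.Extensions[
      flow_limiter_calculator_pb2.FlowLimiterCalculatorOptions.ext]
  opts.max_in_flight = 1  # at most one frame inside the graph at a time
  opts.max_in_queue = 1   # keep only the newest waiting frame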
PiperOrigin-RevId: 497260577 --- .../calculators/core/real_time_flow_limiter_calculator.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index ef3cb9896..e3c92ba52 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; // } // output_stream: "gated_frames" // } -class RealTimeFlowLimiterCalculator : public CalculatorBase { +// +// Please use FlowLimiterCalculator, which replaces this calculator and +// defines a few additional configuration options. +class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.") + RealTimeFlowLimiterCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { int num_data_streams = cc->Inputs().NumEntries(""); From 5a5ff5393a7bfd9e76f7c3c867957eb18c48f80e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 17:29:23 -0800 Subject: [PATCH 228/346] Internal change PiperOrigin-RevId: 497269082 --- mediapipe/framework/api2/builder.h | 2 +- mediapipe/framework/api2/packet.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 19273bf44..2a98c4166 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -398,7 +398,7 @@ template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Node()->Node; +explicit Node() -> Node; #endif // C++17 template <> diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 7933575d3..b1ebb0410 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -181,7 +181,7 @@ template class Packet; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Packet()->Packet; +explicit Packet() -> Packet; #endif // C++17 template <> From 175aff9be8ca719257e15355ecc1b682e7e4e299 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 27 Dec 2022 11:24:50 -0800 Subject: [PATCH 229/346] Update list of issue assignments PiperOrigin-RevId: 498003950 --- .github/bot_config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev From 7e36a5e2ae8c66ef9717d399fa4004f448dde13f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 28 Dec 2022 11:22:52 -0800 Subject: [PATCH 230/346] Set filecmp.cmp(shallow=False) in model_maker unit tests. 
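
filecmp.cmp() defaults to shallow=True, which compares only the os.stat()
signatures (file type, size, mtime) and can report two files as equal without
ever reading their bytes, so a regenerated file with different content could
slip through the metadata comparisons. shallow=False forces a byte-by-byte
comparison. A minimal sketch of the difference, using throwaway placeholder
paths:

  import filecmp
  import os
  import shutil
  import tempfile

  # Two same-size files whose bytes differ.
  d = tempfile.mkdtemp()
  a, b = os.path.join(d, 'a.bin'), os.path.join(d, 'b.bin')
  with open(a, 'wb') as f:
    f.write(b'0123')
  with open(b, 'wb') as f:
    f.write(b'0124')       # same size, different content
  shutil.copystat(a, b)    # align mtimes so the stat signatures match

  filecmp.clear_cache()
  print(filecmp.cmp(a, b))                 # True: stat signatures match
  print(filecmp.cmp(a, b, shallow=False))  # False: contents differ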
PiperOrigin-RevId: 498218578 --- .../python/text/text_classifier/text_classifier_test.py | 6 ++++-- .../python/vision/image_classifier/image_classifier_test.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 7a30d19fd..d2edb78bc 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -72,8 +72,10 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertTrue( - filecmp.cmp(output_metadata_file, - self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + filecmp.cmp( + output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE, + shallow=False)) def test_create_and_train_bert(self): train_data, validation_data = self._get_data() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 6ca21d334..14c67d831 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,7 +135,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) - self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + self.assertTrue( + filecmp.cmp( + output_metadata_file, expected_metadata_file, shallow=False)) def test_continual_training_by_loading_checkpoint(self): mock_stdout = io.StringIO() From 9580f045710327b7a22d738b911af70121e2a79a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 28 Dec 2022 13:57:20 -0800 Subject: [PATCH 231/346] Apply most graph options synchronously PiperOrigin-RevId: 498244085 --- .../audio_classifier/audio_classifier.ts | 7 +- .../audio_classifier/audio_classifier_test.ts | 3 +- .../audio/audio_embedder/audio_embedder.ts | 7 +- .../audio_embedder/audio_embedder_test.ts | 3 +- .../tasks/web/components/processors/BUILD | 26 --- .../processors/base_options.test.ts | 127 --------------- .../web/components/processors/base_options.ts | 80 ---------- mediapipe/tasks/web/core/BUILD | 5 +- mediapipe/tasks/web/core/task_runner.ts | 75 ++++++++- mediapipe/tasks/web/core/task_runner_test.ts | 148 +++++++++++++++++- .../text/text_classifier/text_classifier.ts | 7 +- .../text_classifier/text_classifier_test.ts | 3 +- .../web/text/text_embedder/text_embedder.ts | 7 +- .../text/text_embedder/text_embedder_test.ts | 3 +- mediapipe/tasks/web/vision/core/BUILD | 1 + .../vision/core/vision_task_runner.test.ts | 32 ++-- .../web/vision/core/vision_task_runner.ts | 4 +- .../gesture_recognizer/gesture_recognizer.ts | 8 +- .../gesture_recognizer_test.ts | 3 +- .../vision/hand_landmarker/hand_landmarker.ts | 8 +- .../hand_landmarker/hand_landmarker_test.ts | 3 +- .../image_classifier/image_classifier.ts | 7 +- .../image_classifier/image_classifier_test.ts | 3 +- .../vision/image_embedder/image_embedder.ts | 7 +- .../image_embedder/image_embedder_test.ts | 3 +- .../vision/object_detector/object_detector.ts | 8 +- .../object_detector/object_detector_test.ts | 3 +- 27 files changed, 280 insertions(+), 311 deletions(-) delete mode 
100644 mediapipe/tasks/web/components/processors/base_options.test.ts delete mode 100644 mediapipe/tasks/web/components/processors/base_options.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 7bfca680a..51573f50a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner { * * @param options The options for the audio classifier. */ - override async setOptions(options: AudioClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index d5c0a9429..2089f184f 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -79,7 +79,8 @@ describe('AudioClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioClassifier = new AudioClassifierFake(); - await audioClassifier.setOptions({}); // Initialize graph + await audioClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 246cba883..6a4b8ce39 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner { * * @param options The options for the audio embedder. */ - override async setOptions(options: AudioEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index 2f605ff98..dde61a6e9 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -70,7 +70,8 @@ describe('AudioEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioEmbedder = new AudioEmbedderFake(); - await audioEmbedder.setOptions({}); // Initialize graph + await audioEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', () => { diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 148a08238..cab24293d 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -103,29 +103,3 @@ jasmine_node_test( name = "embedder_options_test", deps = [":embedder_options_test_lib"], ) - -mediapipe_ts_library( - name = "base_options", - srcs = [ - "base_options.ts", - ], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", - "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/tasks/web/core", - ], -) - -mediapipe_ts_library( - name = "base_options_test_lib", - testonly = True, - srcs = ["base_options.test.ts"], - deps = [":base_options"], -) - -jasmine_node_test( - name = "base_options_test", - deps = [":base_options_test_lib"], -) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts deleted file mode 100644 index 6d58be68f..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import 'jasmine'; - -// Placeholder for internal dependency on encodeByteArray -// Placeholder for internal dependency on trusted resource URL builder - -import {convertBaseOptionsToProto} from './base_options'; - -describe('convertBaseOptionsToProto()', () => { - const mockBytes = new Uint8Array([0, 1, 2, 3]); - const mockBytesResult = { - modelAsset: { - fileContent: Buffer.from(mockBytes).toString('base64'), - fileName: undefined, - fileDescriptorMeta: undefined, - filePointerMeta: undefined, - }, - useStreamMode: false, - acceleration: { - xnnpack: undefined, - gpu: undefined, - tflite: {}, - }, - }; - - let fetchSpy: jasmine.Spy; - - beforeEach(() => { - fetchSpy = jasmine.createSpy().and.callFake(async url => { - expect(url).toEqual('foo'); - return { - arrayBuffer: () => mockBytes.buffer, - } as unknown as Response; - }); - global.fetch = fetchSpy; - }); - - it('verifies that at least one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({})) - .toBeRejectedWithError( - /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); - }); - - it('verifies that no more than one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({ - modelAssetPath: `foo`, - modelAssetBuffer: new Uint8Array([]) - })) - .toBeRejectedWithError( - /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); - }); - - it('downloads model', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetPath: `foo`, - }); - - expect(fetchSpy).toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('does not download model when bytes are provided', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - }); - - expect(fetchSpy).not.toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable CPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'CPU', - }); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable GPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - expect(baseOptionsProto.toObject()).toEqual({ - ...mockBytesResult, - acceleration: { - xnnpack: undefined, - gpu: { - useAdvancedGpuApi: false, - api: 0, - allowPrecisionLoss: true, - cachedKernelPath: undefined, - serializedModelDir: undefined, - modelToken: undefined, - usage: 2, - }, - tflite: undefined, - }, - }); - }); - - it('can reset delegate', async () => { - let baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - // Clear backend - baseOptionsProto = - await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); -}); diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts deleted file mode 100644 index 97b62b784..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; -import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; - -// The OSS JS API does not support the builder pattern. -// tslint:disable:jspb-use-builder-pattern - -/** - * Converts a BaseOptions API object to its Protobuf representation. - * @throws If neither a model assset path or buffer is provided - */ -export async function convertBaseOptionsToProto( - updatedOptions: BaseOptions, - currentOptions?: BaseOptionsProto): Promise { - const result = - currentOptions ? currentOptions.clone() : new BaseOptionsProto(); - - await configureExternalFile(updatedOptions, result); - configureAcceleration(updatedOptions, result); - - return result; -} - -/** - * Configues the `externalFile` option and validates that a single model is - * provided. - */ -async function configureExternalFile( - options: BaseOptions, proto: BaseOptionsProto) { - const externalFile = proto.getModelAsset() || new ExternalFile(); - proto.setModelAsset(externalFile); - - if (options.modelAssetPath || options.modelAssetBuffer) { - if (options.modelAssetPath && options.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } - - let modelAssetBuffer = options.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(options.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - externalFile.setFileContent(modelAssetBuffer); - } - - if (!externalFile.hasFileContent()) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); - } -} - -/** Configues the `acceleration` option. */ -function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { - const acceleration = proto.getAcceleration() ?? 
new Acceleration(); - if (options.delegate === 'GPU') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { - acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); - } - proto.setAcceleration(acceleration); -} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 1721661f5..c0d10d28b 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,8 +18,10 @@ mediapipe_ts_library( srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", @@ -53,6 +55,7 @@ mediapipe_ts_library( "task_runner_test.ts", ], deps = [ + ":core", ":task_runner", ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 2011fadef..ffb538b52 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,9 +14,11 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; -import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; +import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; @@ -91,14 +93,52 @@ export abstract class TaskRunner { this.graphRunner.registerModelResourcesGraphService(); } - /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: TaskRunnerOptions): Promise { - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); + /** Configures the task with custom options. */ + abstract setOptions(options: TaskRunnerOptions): Promise; + + /** + * Applies the current set of options, including any base options that have + * not been processed by the task implementation. The options are applied + * synchronously unless a `modelAssetPath` is provided. This ensures that + * for most use cases options are applied directly and immediately affect + * the next inference. 
+   */
+  protected applyOptions(options: TaskRunnerOptions): Promise<void> {
+    const baseOptions: BaseOptions = options.baseOptions || {};
+
+    // Validate that exactly one model is configured
+    if (options.baseOptions?.modelAssetBuffer &&
+        options.baseOptions?.modelAssetPath) {
+      throw new Error(
+          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
+    } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
+                 options.baseOptions?.modelAssetBuffer ||
+                 options.baseOptions?.modelAssetPath)) {
+      throw new Error(
+          'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
+    }
+
+    this.setAcceleration(baseOptions);
+    if (baseOptions.modelAssetPath) {
+      // We don't use `await` here since we want to apply most settings
+      // synchronously.
+      return fetch(baseOptions.modelAssetPath.toString())
+          .then(response => response.arrayBuffer())
+          .then(buffer => {
+            this.setExternalFile(new Uint8Array(buffer));
+            this.refreshGraph();
+          });
+    } else {
+      // Apply the setting synchronously.
+      this.setExternalFile(baseOptions.modelAssetBuffer);
+      this.refreshGraph();
+      return Promise.resolve();
+    }
+  }
+
+  /** Applies the current options to the MediaPipe graph. */
+  protected abstract refreshGraph(): void;
+
   /**
    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
    * over the video stream. Will replace the previously running MediaPipe graph,
@@ -140,6 +180,27 @@ export abstract class TaskRunner {
     }
     this.processingErrors = [];
   }
+
+  /** Configures the `externalFile` option */
+  private setExternalFile(modelAssetBuffer?: Uint8Array): void {
+    const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
+    if (modelAssetBuffer) {
+      externalFile.setFileContent(modelAssetBuffer);
+    }
+    this.baseOptions.setModelAsset(externalFile);
+  }
+
+  /** Configures the `acceleration` option. */
+  private setAcceleration(options: BaseOptions) {
+    const acceleration =
+        this.baseOptions.getAcceleration() ??
new Acceleration(); + if (options.delegate === 'GPU') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + } else { + acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); + } + this.baseOptions.setAcceleration(acceleration); + } } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index c9aad9d25..a55ac04d7 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -15,18 +15,22 @@ */ import 'jasmine'; +// Placeholder for internal dependency on encodeByteArray import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource URL builder import {GraphRunnerImageLib} from './task_runner'; +import {TaskRunnerOptions} from './task_runner_options.d'; class TaskRunnerFake extends TaskRunner { - protected baseOptions = new BaseOptionsProto(); private errorListener: ErrorListener|undefined; private errors: string[] = []; + baseOptions = new BaseOptionsProto(); + static createFake(): TaskRunnerFake { const wasmModule = createSpyWasmModule(); return new TaskRunnerFake(wasmModule); @@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner { super.finishProcessing(); } + override refreshGraph(): void {} + override setGraph(graphData: Uint8Array, isBinary: boolean): void { super.setGraph(graphData, isBinary); } + setOptions(options: TaskRunnerOptions): Promise { + return this.applyOptions(options); + } + private throwErrors(): void { expect(this.errorListener).toBeDefined(); for (const error of this.errors) { @@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner { } describe('TaskRunner', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + let taskRunner: TaskRunnerFake; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + + taskRunner = TaskRunnerFake.createFake(); + }); + it('handles errors during graph update', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error'); expect(() => { @@ -85,7 +125,6 @@ describe('TaskRunner', () => { }); it('handles errors during graph execution', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.enqueueError('Test error'); @@ -96,7 +135,6 @@ describe('TaskRunner', () => { }); it('can handle multiple errors', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error 1'); taskRunner.enqueueError('Test error 2'); @@ -104,4 +142,106 @@ describe('TaskRunner', () => { taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); }).toThrowError(/Test error 1, Test error 2/); }); + + it('verifies that at least one model asset option is provided', () 
=> { + expect(() => { + taskRunner.setOptions({}); + }) + .toThrowError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({ + baseOptions: { + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + } + }); + }) + .toThrowError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('doesn\'t require model once it is configured', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + expect(() => { + taskRunner.setOptions({}); + }).not.toThrowError(); + }); + + it('downloads model', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetPath: `foo`}}); + + expect(fetchSpy).toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('changes model synchronously when bytes are provided', () => { + const resolvedPromise = taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + // Check that the change has been applied even though we do not await the + // above Promise + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + return resolvedPromise; + }); + + it('can enable CPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'CPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + // Clear backend + await taskRunner.setOptions({baseOptions: {delegate: undefined}}); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); }); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 62708700a..981438625 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
*/ - override async setOptions(options: TextClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts index 841bf8c48..5578362cb 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -56,7 +56,8 @@ describe('TextClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textClassifier = new TextClassifierFake(); - await textClassifier.setOptions({}); // Initialize graph + await textClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 611233e02..7aa0aa6b9 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. */ - override async setOptions(options: TextEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts index 04a9b371a..2804e4deb 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -56,7 +56,8 @@ describe('TextEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textEmbedder = new TextEmbedderFake(); - await textEmbedder.setOptions({}); // Initialize graph + await textEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e4ea3036f..03958a819 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -29,6 +29,7 @@ mediapipe_ts_library( testonly = True, srcs = ["vision_task_runner.test.ts"], deps = [ + ":vision_task_options", ":vision_task_runner", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/web/core:task_runner_test_utils", diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 6cc9ea328..d77cc4fed 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {VisionTaskOptions} from './vision_task_options'; import {VisionTaskRunner} from './vision_task_runner'; class VisionTaskRunnerFake extends VisionTaskRunner { @@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner { protected override process(): void {} + protected override refreshGraph(): void {} + + override setOptions(options: VisionTaskOptions): Promise { + return this.applyOptions(options); + } + override processImageData(image: ImageSource): void { super.processImageData(image); } @@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner { } describe('VisionTaskRunner', () => { - const streamMode = { - modelAsset: undefined, - useStreamMode: true, - acceleration: undefined, - }; - - const imageMode = { - modelAsset: undefined, - useStreamMode: false, - acceleration: undefined, - }; - let visionTaskRunner: VisionTaskRunnerFake; - beforeEach(() => { + beforeEach(async () => { visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { await visionTaskRunner.setOptions({runningMode: 'image'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { await visionTaskRunner.setOptions({runningMode: 'video'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: true})); }); 
it('can clear running mode', async () => { @@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => { // Clear running mode await visionTaskRunner.setOptions({runningMode: undefined}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('cannot process images with video mode', async () => { diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 3432b521b..952990326 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { /** Configures the shared options of a vision task. */ - override async setOptions(options: VisionTaskOptions): Promise { - await super.setOptions(options); + override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; this.baseOptions.setUseStreamMode(useStreamMode); } + return super.applyOptions(options); } /** Sends an image packet to the graph and awaits results. */ diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index b6b795076..cfeb179f5 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -169,9 +169,7 @@ export class GestureRecognizer extends * * @param options The options for the gesture recognizer. */ - override async setOptions(options: GestureRecognizerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: GestureRecognizerOptions): Promise { if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( options.numHands ?? DEFAULT_NUM_HANDS); @@ -221,7 +219,7 @@ export class GestureRecognizer extends ?.clearClassifierOptions(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -342,7 +340,7 @@ export class GestureRecognizer extends } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index c0f0d1554..ff6bba613 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -109,7 +109,8 @@ describe('GestureRecognizer', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); gestureRecognizer = new GestureRecognizerFake(); - await gestureRecognizer.setOptions({}); // Initialize graph + await gestureRecognizer.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 2a0e8286c..24cf9a402 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner { * * @param options The options for the hand landmarker. */ - override async setOptions(options: HandLandmarkerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: HandLandmarkerOptions): Promise { // Configure hand detector options. if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner { options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index fc26680e0..76e77b4bf 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -98,7 +98,8 @@ describe('HandLandmarker', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); handLandmarker = new HandLandmarkerFake(); - await handLandmarker.setOptions({}); // Initialize graph + await handLandmarker.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 36e7311fb..9298a860c 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner { * * @param options The options for the image classifier. 
*/ - override async setOptions(options: ImageClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts index 2041a0cef..da4a01d02 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -61,7 +61,8 @@ describe('ImageClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageClassifier = new ImageClassifierFake(); - await imageClassifier.setOptions({}); // Initialize graph + await imageClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 0c45ba5e7..cf0bd8c5d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner { * * @param options The options for the image embedder. */ - override async setOptions(options: ImageEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index cafe0f3d8..b63bb374c 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -57,7 +57,8 @@ describe('ImageEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageEmbedder = new ImageEmbedderFake(); - await imageEmbedder.setOptions({}); // Initialize graph + await imageEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index fbfaced12..e4c51de08 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner { * * @param options The options for the object detector. */ - override async setOptions(options: ObjectDetectorOptions): Promise { - await super.setOptions(options); - + override setOptions(options: ObjectDetectorOptions): Promise { // Note that we have to support both JSPB and ProtobufJS, hence we // have to expliclity clear the values instead of setting them to // `undefined`. @@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner { this.options.clearCategoryDenylistList(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index fff1a1c48..43b7035d5 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -61,7 +61,8 @@ describe('ObjectDetector', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); objectDetector = new ObjectDetectorFake(); - await objectDetector.setOptions({}); // Initialize graph + await objectDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { From 1924f1cdff94af953c2cd9b01a13d623ea13e7a7 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 28 Dec 2022 14:27:42 -0800 Subject: [PATCH 232/346] Tensor: Fix use_ahwb_ flag and tests on local device involved. 
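
For readers skimming the diff below: the core of the fix is that `use_ahwb_` now latches once set instead of being recomputed from the usage map on every tracking call. A minimal standalone sketch of the before/after semantics (illustrative names only, not the real Tensor members):

    #include <cstdint>
    #include <unordered_set>

    struct AhwbUsageSketch {
      std::unordered_set<uint64_t> seen_keys;
      bool use_ahwb = false;

      void Track(uint64_t key) {
        // Before the fix the flag was recomputed, so an unseen key could
        // flip an already AHWB-enabled tensor back to false:
        //   use_ahwb = seen_keys.count(key) > 0;
        // After the fix the flag is sticky once enabled:
        use_ahwb = use_ahwb || seen_keys.count(key) > 0;
      }
    };
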
PiperOrigin-RevId: 498249332 --- mediapipe/framework/formats/tensor_ahwb.cc | 3 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 16 ++++++-- .../framework/formats/tensor_ahwb_test.cc | 39 ++++--------------- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 466811be7..74b2dca93 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -458,7 +458,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); } } - use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); + // Keep flag value if it was set previously. + use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_); } #else // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index a6ca00949..e2ad869f9 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase { }; TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { } TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { // Request the CPU view to get the memory to be allocated. // Request Ahwb view then to transform the storage into Ahwb. - Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; { @@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { // Request the GPU view to get the ssbo allocated internally. // Request Ahwb view then to transform the storage into Ahwb. 
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
   RunInGlContext([&tensor] {
diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc
index 7ab5a4925..f0baa6303 100644
--- a/mediapipe/framework/formats/tensor_ahwb_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_test.cc
@@ -1,34 +1,28 @@
 #include "mediapipe/framework/formats/tensor.h"
 
-#include "mediapipe/gpu/gpu_test_base.h"
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"
 
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-#if !MEDIAPIPE_DISABLE_GPU
-
 namespace mediapipe {
 
-class TensorAhwbTest : public mediapipe::GpuTestBase {
- public:
-};
-
-TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
+TEST(TensorAhwbTest, TestCpuThenAHWB) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
     auto ptr = tensor.GetCpuWriteView().buffer<float>();
     EXPECT_NE(ptr, nullptr);
   }
   {
-    auto ahwb = tensor.GetAHardwareBufferReadView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
   }
 }
 
-TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
+TEST(TensorAhwbTest, TestAHWBThenCpu) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
-    auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
   }
   {
     auto ptr = tensor.GetCpuReadView().buffer<float>();
@@ -36,21 +30,4 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
   }
 }
 
-TEST_F(TensorAhwbTest, TestCpuThenGl) {
-  RunInGlContext([] {
-    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
-    {
-      auto ptr = tensor.GetCpuWriteView().buffer<float>();
-      EXPECT_NE(ptr, nullptr);
-    }
-    {
-      auto ssbo = tensor.GetOpenGlBufferReadView().name();
-      EXPECT_GT(ssbo, 0);
-    }
-  });
-}
-
 }  // namespace mediapipe
-
-#endif  // !MEDIAPIPE_DISABLE_GPU
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB

From 2d9a969d10bdcac98e0e86f617817e08cf656331 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Wed, 28 Dec 2022 16:07:09 -0800
Subject: [PATCH 233/346] Tensor1: memorize size_alignment when tracking the
 ahwb usage. When a CPU/GPU buffer is allocated and the tracker selects Ahwb
 storage, the previously recorded alignment must be used.
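
A rough sketch of the bookkeeping this adds (simplified stand-ins for the real Tensor members; the actual change is in AllocateAHardwareBuffer below):

    #include <cstdint>
    #include <unordered_map>

    // Tracking keys map to the size alignment recorded when an AHWB view was
    // first requested, so a later allocation that passes no alignment (for
    // example one triggered from a CPU/GPU view) reuses the recorded value.
    std::unordered_map<uint64_t, int> usage_track;

    int ResolveAlignment(uint64_t key, int requested_alignment) {
      if (auto it = usage_track.find(key); it != usage_track.end()) {
        return it->second;  // reuse the previously recorded alignment
      }
      if (key != 0) {
        usage_track.insert({key, requested_alignment});  // remember it
      }
      return requested_alignment;
    }
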
PiperOrigin-RevId: 498264759 --- mediapipe/framework/formats/BUILD | 2 +- mediapipe/framework/formats/tensor.h | 7 +- mediapipe/framework/formats/tensor_ahwb.cc | 7 +- .../framework/formats/tensor_ahwb_test.cc | 67 +++++++++++++++++++ 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f5a043f10..cce7e5bd0 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -455,7 +455,7 @@ cc_library( ], }), deps = [ - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 8a6f02e9d..0f19bb5ee 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,7 +24,7 @@ #include #include -#include "absl/container/flat_hash_set.h" +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" @@ -434,8 +434,9 @@ class Tensor { mutable bool use_ahwb_ = false; mutable uint64_t ahwb_tracking_key_ = 0; // TODO: Tracks all unique tensors. Can grow to a large number. LRU - // can be more predicted. - static inline absl::flat_hash_set ahwb_usage_track_; + // (Least Recently Used) can be more predicted. + // The value contains the size alignment parameter. + static inline absl::flat_hash_map ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 74b2dca93..525f05f31 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { // Mark current tracking key as Ahwb-use. - ahwb_usage_track_.insert(ahwb_tracking_key_); + if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_); + it != ahwb_usage_track_.end()) { + size_alignment = it->second; + } else if (ahwb_tracking_key_ != 0) { + ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment}); + } use_ahwb_ = true; if (__builtin_available(android 26, *)) { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index f0baa6303..3da6ca8d3 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -30,4 +30,71 @@ TEST(TensorAhwbTest, TestAHWBThenCpu) { } } +TEST(TensorAhwbTest, TestAhwbAlignment) { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5}); + { + auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 5 = 20, the closest aligned to 16 size is 32. + EXPECT_EQ(desc.width, 32); + } + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } +} + +// Tensor::GetCpuView uses source location mechanism that gives source file name +// and line from where the method is called. The function is intended just to +// have two calls providing the same source file name and line. 
+auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
+
+// The test checks the tracking mechanism: when a tensor's Cpu view is
+// retrieved for the first time, the source location is attached to the
+// tensor. If the Ahwb view is then requested from the tensor, the previously
+// recorded Cpu view source location is marked as using Ahwb storage.
+// When a Cpu view with the same source location (but for a newly allocated
+// tensor) is requested and the location is marked to use Ahwb storage, the
+// Ahwb storage is allocated for the Cpu view.
+TEST(TensorAhwbTest, TestTrackingAhwb) {
+  // Create first tensor and request Cpu and then Ahwb view to mark the source
+  // location for Ahwb storage.
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Align size of the Ahwb by multiple of 16.
+      auto view = tensor.GetAHardwareBufferWriteView(16);
+      EXPECT_NE(view.handle(), nullptr);
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      // The second tensor uses the same Cpu view source location so Ahwb
+      // storage is allocated internally.
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Check the Ahwb size to be aligned to multiple of 16. The alignment is
+      // stored by previous requesting of the Ahwb view.
+      auto view = tensor.GetAHardwareBufferReadView();
+      EXPECT_NE(view.handle(), nullptr);
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_Desc desc;
+        AHardwareBuffer_describe(view.handle(), &desc);
+        // sizeof(float) * 9 = 36. The closest aligned size is 48.
+        EXPECT_EQ(desc.width, 48);
+      }
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
+}
+
 }  // namespace mediapipe

From aaa16eca1fedf9450689be422ea2dc01c7d74c93 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 29 Dec 2022 08:33:58 -0800
Subject: [PATCH 234/346] Sets the graph service packets before initializing
 (and validating the graph) in the objc graph wrapper.

PiperOrigin-RevId: 498393761
---
 mediapipe/objc/MPPGraph.mm | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm
index 1bd177e80..3123eb863 100644
--- a/mediapipe/objc/MPPGraph.mm
+++ b/mediapipe/objc/MPPGraph.mm
@@ -230,16 +230,17 @@ if ([wrapper.delegate
 }
 
 - (absl::Status)performStart {
-  absl::Status status = _graph->Initialize(_config);
-  if (!status.ok()) {
-    return status;
-  }
+  absl::Status status;
   for (const auto& service_packet : _servicePackets) {
     status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
     if (!status.ok()) {
       return status;
     }
   }
+  status = _graph->Initialize(_config);
+  if (!status.ok()) {
+    return status;
+  }
   status = _graph->StartRun(_inputSidePackets, _streamHeaders);
   if (!status.ok()) {
     return status;

From 60c6b155f626f40e2971cda10aa4c3565897874a Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 29 Dec 2022 10:16:10 -0800
Subject: [PATCH 235/346] Save an integer id in graph profiler objects to
 distinguish between different profiler instances during benchmarking.
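
The id scheme is easiest to see in isolation; a minimal sketch of the behavior (the real members are next_instance_id_ and graph_id_ in GraphProfiler below):

    #include <atomic>
    #include <cstdint>

    class ProfilerIdSketch {
     public:
      // Called once per profiler initialization; ids are unique within the
      // process and are never reused, even after a profiler is destroyed.
      void Initialize() { graph_id_ = ++next_instance_id_; }
      uint64_t GetGraphId() const { return graph_id_; }

     private:
      static inline std::atomic_int next_instance_id_ = 0;
      uint64_t graph_id_ = 0;
    };
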
PiperOrigin-RevId: 498409363 --- .../framework/profiler/graph_profiler.cc | 1 + mediapipe/framework/profiler/graph_profiler.h | 9 +++++++ .../framework/profiler/graph_profiler_test.cc | 26 +++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index f14acfc78..6aead5250 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -194,6 +194,7 @@ void GraphProfiler::Initialize( "Calculator \"$0\" has already been added.", node_name); } profile_builder_ = std::make_unique(this); + graph_id_ = ++next_instance_id_; is_initialized_ = true; } diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 23caed4ec..6358cb057 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this { return validated_graph_; } + // Gets a numerical identifier for this GraphProfiler object. + uint64_t GetGraphId() { return graph_id_; } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a @@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this { class GraphProfileBuilder; std::unique_ptr profile_builder_; + // The globally incrementing identifier for all graphs in a process. + static inline std::atomic_int next_instance_id_ = 0; + + // A unique identifier for this object. Only unique within a process. + uint64_t graph_id_; + // For testing. friend GraphProfilerTestPeer; }; diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 81ba90cda..75d1c7ebd 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) { "Cannot initialize .* multiple times."); } +// Tests that graph identifiers are not reused, even after destruction. +TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) { + auto raw_graph_config = R"( + profiler_config { + enable_profiler: true + } + input_stream: "input_stream" + node { + calculator: "DummyTestCalculator" + input_stream: "input_stream" + })"; + const int n_iterations = 100; + absl::flat_hash_set seen_ids; + for (int i = 0; i < n_iterations; ++i) { + std::shared_ptr profiler = + std::make_shared(); + auto graph_config = CreateGraphConfig(raw_graph_config); + mediapipe::ValidatedGraphConfig validated_graph; + QCHECK_OK(validated_graph.Initialize(graph_config)); + profiler->Initialize(validated_graph); + + int id = profiler->GetGraphId(); + ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id))); + seen_ids.insert(id); + } +} // Tests that Pause(), Resume(), and Reset() works. 
TEST_F(GraphProfilerTestPeer, PauseResumeReset) { InitializeProfilerWithGraphConfig(R"( From 9252a025e5604cb61b11cbf23943dc7fb9e6f679 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 30 Dec 2022 04:56:57 -0800 Subject: [PATCH 236/346] Use custom gesture options in GestureRecognizer PiperOrigin-RevId: 498567432 --- .../tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 01f444742..91a5ec213 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto custom_gestures_classifier_options_proto = std::make_unique( components::processors::ConvertClassifierOptionsToProto( - &(options->canned_gestures_classifier_options))); + &(options->custom_gestures_classifier_options))); hand_gesture_recognizer_graph_options ->mutable_custom_gesture_classifier_graph_options() ->mutable_classifier_options() - ->Swap(canned_gestures_classifier_options_proto.get()); + ->Swap(custom_gestures_classifier_options_proto.get()); return options_proto; } From 2f4bb5d545fbd6b6389248b7123635dcdfff02b7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 3 Jan 2023 09:34:21 -0800 Subject: [PATCH 237/346] Use utility framebuffer in ViewDoneWritingSimulatorWorkaround This code needs a FBO to bind the texture. Fixes invalid results when running under simulator. PiperOrigin-RevId: 499241867 --- .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 75 +++++++++++-------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index 014cc1c69..7cac32b7f 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -74,42 +74,51 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, const GlTextureView& view) { CHECK(pixel_buffer); - CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferLockBaseAddress failed: " << err; - OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); - size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); - uint8_t* pixel_ptr = - static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); - if (pixel_format == kCVPixelFormatType_32BGRA) { - // TODO: restore previous framebuffer? Move this to helper so we - // can use BindFramebuffer? 
- glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); + auto ctx = GlContext::GetCurrent().get(); + if (!ctx) ctx = view.gl_context(); + ctx->Run([pixel_buffer, &view, ctx] { + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << err; + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); + uint8_t* pixel_ptr = + static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); + if (pixel_format == kCVPixelFormatType_32BGRA) { + glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx)); + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), view.name(), 0); - size_t contiguous_bytes_per_row = view.width() * 4; - if (bytes_per_row == contiguous_bytes_per_row) { - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - pixel_ptr); - } else { - std::vector contiguous_buffer(contiguous_bytes_per_row * - view.height()); - uint8_t* temp_ptr = contiguous_buffer.data(); - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - temp_ptr); - for (int i = 0; i < view.height(); ++i) { - memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); - temp_ptr += contiguous_bytes_per_row; - pixel_ptr += bytes_per_row; + size_t contiguous_bytes_per_row = view.width() * 4; + if (bytes_per_row == contiguous_bytes_per_row) { + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, pixel_ptr); + } else { + // TODO: use GL_PACK settings for row length. We can expect + // GLES 3.0 on iOS now. + std::vector contiguous_buffer(contiguous_bytes_per_row * + view.height()); + uint8_t* temp_ptr = contiguous_buffer.data(); + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, temp_ptr); + for (int i = 0; i < view.height(); ++i) { + memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); + temp_ptr += contiguous_bytes_per_row; + pixel_ptr += bytes_per_row; + } } + // TODO: restore previous framebuffer? + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), 0, 0); + glBindFramebuffer(GL_FRAMEBUFFER, 0); + } else { + LOG(ERROR) << "unsupported pixel format: " << pixel_format; } - } else { - LOG(ERROR) << "unsupported pixel format: " << pixel_format; - } - err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferUnlockBaseAddress failed: " << err; + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << err; + }); } #endif // TARGET_IPHONE_SIMULATOR From f53c0eaceeae9b7cb622764d78054f8e44222ba3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 Jan 2023 09:38:02 -0800 Subject: [PATCH 238/346] Extend tag conversion behavior to also convert `:` (in addition to the current `/`, `-`, and `.`) to `_`. 
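
A condensed sketch of the conversion after this change (it mirrors MaybeConvertSignatureToTag in the diffs below, with the convert_signature_to_tags option check and logging omitted):

    #include <algorithm>
    #include <cctype>
    #include <string>

    #include "absl/strings/str_replace.h"

    std::string SignatureToTag(const std::string& name) {
      std::string output(name.size(), '\0');
      std::transform(name.begin(), name.end(), output.begin(),
                     [](unsigned char c) { return std::toupper(c); });
      // '/', '-', '.' and now ':' all become '_'.
      return absl::StrReplaceAll(
          output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
    }

    // e.g. SignatureToTag("serving_default:output.0") yields
    // "SERVING_DEFAULT_OUTPUT_0".
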
PiperOrigin-RevId: 499243005
---
 .../tensorflow_session_from_saved_model_calculator.cc | 7 +++----
 .../tensorflow_session_from_saved_model_calculator.proto | 4 ++--
 .../tensorflow_session_from_saved_model_generator.cc | 7 +++----
 .../tensorflow_session_from_saved_model_generator.proto | 4 ++--
 4 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
index 922eb9d50..18bddbbe3 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
@@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) {
 }
 
 // If options.convert_signature_to_tags() is set, will convert letters to
-// uppercase and replace /'s and -'s with _'s. This enables the standard
+// uppercase and replace /, -, . and :'s with _'s. This enables the standard
 // SavedModel classification, regression, and prediction signatures to be used
 // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
 // patterns.
@@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag(
     output.resize(name.length());
     std::transform(name.begin(), name.end(), output.begin(),
                    [](unsigned char c) { return std::toupper(c); });
-    output = absl::StrReplaceAll(output, {{"/", "_"}});
-    output = absl::StrReplaceAll(output, {{"-", "_"}});
-    output = absl::StrReplaceAll(output, {{".", "_"}});
+    output = absl::StrReplaceAll(
+        output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
     LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto
index 927d3b51f..515b46fa9 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto
@@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
   // The name of the generic signature to load into the mapping from tags to
   // tensor names.
   optional string signature_name = 2 [default = "serving_default"];
-  // Whether to convert the signature keys to uppercase as well as switch /'s
-  // and -'s to _'s, which enables common signatures to be used as Tags.
+  // Whether to convert the signature keys to uppercase as well as switch
+  // /, -, . and :'s to _'s, which enables common signatures to be used as Tags.
   optional bool convert_signature_to_tags = 3 [default = true];
   // If true, saved_model_path can have multiple exported models in
   // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
index d5236f1cc..ee69ec56a 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
@@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) {
 }
 
 // If options.convert_signature_to_tags() is set, will convert letters to
-// uppercase and replace /'s and -'s with _'s. This enables the standard
+// uppercase and replace /, -, ., and :'s with _'s. This enables the standard
 // SavedModel classification, regression, and prediction signatures to be used
 // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
 // patterns.
@@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag(
     output.resize(name.length());
     std::transform(name.begin(), name.end(), output.begin(),
                    [](unsigned char c) { return std::toupper(c); });
-    output = absl::StrReplaceAll(output, {{"/", "_"}});
-    output = absl::StrReplaceAll(output, {{"-", "_"}});
-    output = absl::StrReplaceAll(output, {{".", "_"}});
+    output = absl::StrReplaceAll(
+        output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
     LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto
index d24a1cd73..d45fcb662 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto
@@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
   // The name of the generic signature to load into the mapping from tags to
   // tensor names.
   optional string signature_name = 2 [default = "serving_default"];
-  // Whether to convert the signature keys to uppercase as well as switch /'s
-  // and -'s to _'s, which enables common signatures to be used as Tags.
+  // Whether to convert the signature keys to uppercase, as well as switch /'s
+  // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
   optional bool convert_signature_to_tags = 3 [default = true];
   // If true, saved_model_path can have multiple exported models in
   // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,

From 987f4dc1ed89801e54c408abd670f63ce0c77007 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 3 Jan 2023 10:52:41 -0800
Subject: [PATCH 239/346] Make addJasmineCustomFloatEqualityTester configurable

PiperOrigin-RevId: 499263931
---
 mediapipe/tasks/web/core/task_runner_test_utils.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts
index 2a1161a55..838b3f585 100644
--- a/mediapipe/tasks/web/core/task_runner_test_utils.ts
+++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts
@@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
  * Sets up our equality testing to use a custom float equality checking function
  * to avoid incorrect test results due to minor floating point inaccuracies.
*/ -export function addJasmineCustomFloatEqualityTester() { +export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) { jasmine.addCustomEqualityTester((a, b) => { // Custom float equality if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { - return Math.abs(a - b) < 5e-8; + return Math.abs(a - b) < tolerance; } return; }); From 68f247a5c7a2f081e6f0ff8b25b9187de5646e2b Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 3 Jan 2023 12:03:57 -0800 Subject: [PATCH 240/346] Internal change PiperOrigin-RevId: 499282085 --- .../web/vision/hand_landmarker/hand_landmarker_result.d.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 89f867d69..8a6d9bfa6 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts @@ -17,6 +17,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +export {Landmark, NormalizedLandmark, Category}; + /** * Represents the hand landmarks deection results generated by `HandLandmarker`. */ From 75b87e0e321090bf73653d83ebfa69cf6f73621f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 3 Jan 2023 12:09:59 -0800 Subject: [PATCH 241/346] Internal change PiperOrigin-RevId: 499283559 --- .../gesture_recognizer/gesture_recognizer.ts | 35 ++++++++++++++----- .../gesture_recognizer_result.d.ts | 8 ++++- .../gesture_recognizer_test.ts | 23 +++++++++++- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index cfeb179f5..c77f2c67a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -263,12 +263,22 @@ export class GestureRecognizer extends NORM_RECT_STREAM, timestamp); this.finishProcessing(); - return { - gestures: this.gestures, - landmarks: this.landmarks, - worldLandmarks: this.worldLandmarks, - handednesses: this.handednesses - }; + if (this.gestures.length === 0) { + // If no gestures are detected in the image, just return an empty list + return { + gestures: [], + landmarks: [], + worldLandmarks: [], + handednesses: [], + }; + } else { + return { + gestures: this.gestures, + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } } /** Sets the default values for the graph. */ @@ -283,15 +293,19 @@ export class GestureRecognizer extends } /** Converts the proto data to a Category[][] structure. */ - private toJsCategories(data: Uint8Array[]): Category[][] { + private toJsCategories(data: Uint8Array[], populateIndex = true): + Category[][] { const result: Category[][] = []; for (const binaryProto of data) { const inputList = ClassificationList.deserializeBinary(binaryProto); const outputList: Category[] = []; for (const classification of inputList.getClassificationList()) { + const index = populateIndex && classification.hasIndex() ? + classification.getIndex()! : + DEFAULT_CATEGORY_INDEX; outputList.push({ score: classification.getScore() ?? 0, - index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + index, categoryName: classification.getLabel() ?? 
'', displayName: classification.getDisplayName() ?? '', }); @@ -375,7 +389,10 @@ export class GestureRecognizer extends }); this.graphRunner.attachProtoVectorListener( HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); + // Gesture index is not used, because the final gesture result comes + // from multiple classifiers. + this.gestures.push( + ...this.toJsCategories(binaryProto, /* populateIndex= */ false)); }); this.graphRunner.attachProtoVectorListener( HANDEDNESS_STREAM, binaryProto => { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts index e570270b2..323290008 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -17,6 +17,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +export {Category, Landmark, NormalizedLandmark}; + /** * Represents the gesture recognition results generated by `GestureRecognizer`. */ @@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult { /** Handedness of detected hands. */ handednesses: Category[][]; - /** Recognized hand gestures of detected hands */ + /** + * Recognized hand gestures of detected hands. Note that the index of the + * gesture is always -1, because the raw indices from multiple gesture + * classifiers cannot consolidate to a meaningful index. + */ gestures: Category[][]; } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index ff6bba613..ee51fd32a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -272,7 +272,7 @@ describe('GestureRecognizer', () => { expect(gestures).toEqual({ 'gestures': [[{ 'score': 0.2, - 'index': 2, + 'index': -1, 'categoryName': 'gesture_label', 'displayName': 'gesture_display_name' }]], @@ -305,4 +305,25 @@ describe('GestureRecognizer', () => { // gestures. expect(gestures2).toEqual(gestures1); }); + + it('returns empty results when no gestures are detected', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(gestureRecognizer); + gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('world_hand_landmarks')! 
+ (createWorldLandmarks()); + gestureRecognizer.listeners.get('handedness')!(createHandednesses()); + gestureRecognizer.listeners.get('hand_gestures')!([]); + }); + + // Invoke the gesture recognizer + const gestures = gestureRecognizer.recognize({} as HTMLImageElement); + expect(gestures).toEqual({ + 'gestures': [], + 'landmarks': [], + 'worldLandmarks': [], + 'handednesses': [] + }); + }); }); From e7dc989f715382c10ac6d714f4f4be5d330f903d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 Jan 2023 14:12:34 -0800 Subject: [PATCH 242/346] Internal Change PiperOrigin-RevId: 499313491 --- mediapipe/examples/desktop/autoflip/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 562f11c49..0e28746dc 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -30,6 +30,10 @@ proto_library( java_lite_proto_library( name = "autoflip_messages_java_proto_lite", + visibility = [ + "//java/com/google/android/apps/photos:__subpackages__", + "//javatests/com/google/android/apps/photos:__subpackages__", + ], deps = [ ":autoflip_messages_proto", ], From add5600d0d4e9f0213ebf58088301dc7e743194a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 Jan 2023 17:18:59 -0800 Subject: [PATCH 243/346] Internal change PiperOrigin-RevId: 499351795 --- .../python/text/text_classifier/text_classifier_test.py | 1 + .../python/vision/image_classifier/image_classifier_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index d2edb78bc..eb4443b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -71,6 +71,7 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( filecmp.cmp( output_metadata_file, diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 14c67d831..afda8643b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,6 +135,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( filecmp.cmp( output_metadata_file, expected_metadata_file, shallow=False)) From a4ea606eac3adf3ca5e149e9e6ff6573168971a6 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 4 Jan 2023 08:21:55 -0800 Subject: [PATCH 244/346] Internal change. 
PiperOrigin-RevId: 499490514 --- .../framework/formats/tensor_ahwb_gpu_test.cc | 28 +++++++++---------- .../framework/formats/tensor_ahwb_test.cc | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index e2ad869f9..45d341e20 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -83,8 +83,8 @@ void FillGpuBuffer(GLuint name, std::size_t size, TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name)); MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program)); MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1)); - MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); - MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); } class TensorAhwbGpuTest : public mediapipe::GpuTestBase { @@ -97,18 +97,18 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { { // Request Ahwb first to get Ahwb storage allocated internally. auto view = tensor.GetAHardwareBufferWriteView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetWritingFinishedFD(-1, [](bool) { return true; }); } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); - EXPECT_GT(ssbo_name, 0); + ASSERT_GT(ssbo_name, 0); FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), tensor.element_type()); }); auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { @@ -124,18 +124,18 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { { // Request Ahwb first to get Ahwb storage allocated internally. 
auto view = tensor.GetAHardwareBufferWriteView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetReadingFinishedFunc([](bool) { return true; }); } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); - EXPECT_GT(ssbo_name, 0); + ASSERT_GT(ssbo_name, 0); FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), tensor.element_type()); }); auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { @@ -153,18 +153,18 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; { auto ptr = tensor.GetCpuWriteView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); for (int i = 0; i < num_elements; i++) { ptr[i] = static_cast(i) / 10.0f; } } { auto view = tensor.GetAHardwareBufferReadView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { @@ -182,17 +182,17 @@ TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); - EXPECT_GT(ssbo_name, 0); + ASSERT_GT(ssbo_name, 0); FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), tensor.element_type()); }); { auto view = tensor.GetAHardwareBufferReadView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index 3da6ca8d3..69e49dd58 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -34,7 +34,7 @@ TEST(TensorAhwbTest, TestAhwbAlignment) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5}); { auto view = tensor.GetAHardwareBufferWriteView(16); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); if (__builtin_available(android 26, *)) { AHardwareBuffer_Desc desc; AHardwareBuffer_describe(view.handle(), &desc); From 9a70af146432dcbbbc961f9c1a5af4a039d0909a Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 4 Jan 2023 08:52:03 -0800 Subject: [PATCH 245/346] Internal change. 
PiperOrigin-RevId: 499496793 --- mediapipe/framework/formats/tensor_ahwb_gpu_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index 45d341e20..ff78d1f88 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -68,8 +68,9 @@ void FillGpuBuffer(GLuint name, std::size_t size, MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH, &max_length)); std::vector error_log(max_length); - glGetShaderInfoLog(shader, max_length, &max_length, error_log.data()); - glDeleteShader(shader); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderInfoLog, shader, max_length, + &max_length, error_log.data())); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader)); FAIL() << error_log.data(); return; } From e3131d7d7856771def3c1c141720ca311ed0f3d9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 4 Jan 2023 10:31:04 -0800 Subject: [PATCH 246/346] Internal change PiperOrigin-RevId: 499521620 --- mediapipe/model_maker/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index ea193db94..7114e2080 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -132,9 +132,9 @@ setuptools.setup( 'Operating System :: MacOS :: MacOS X', 'Operating System :: Microsoft :: Windows', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', From 24cc0672c47b0b2fac28bbc8434e93a9fccb47ad Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 4 Jan 2023 10:57:33 -0800 Subject: [PATCH 247/346] Internal change PiperOrigin-RevId: 499529022 --- mediapipe/examples/desktop/autoflip/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 0e28746dc..340205caa 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -18,6 +18,8 @@ licenses(["notice"]) package(default_visibility = [ "//mediapipe/examples:__subpackages__", + "//photos/editing/mobile/mediapipe/calculators:__subpackages__", + "//photos/editing/mobile/mediapipe/proto:__subpackages__", ]) proto_library( @@ -45,6 +47,8 @@ mediapipe_cc_proto_library( cc_deps = ["//mediapipe/framework:calculator_cc_proto"], visibility = [ "//mediapipe/examples:__subpackages__", + "//photos/editing/mobile/mediapipe/calculators:__pkg__", + "//photos/editing/mobile/mediapipe/calculators:__subpackages__", ], deps = [":autoflip_messages_proto"], ) From 43bf02443c1b8b7f237c9f7ef408da5cb56619b8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 4 Jan 2023 17:31:48 -0800 Subject: [PATCH 248/346] Option to remove overlapping values computed for different timestamps. 
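
In short, the first output vector is emitted in full and each later vector drops its first `overlap` values, which duplicate the tail of the previous timestamp's output. A condensed sketch of RemoveOverlapVector from the diff below, written as a free function rather than the calculator method:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // `overlapping_values` carries state across calls and starts at 0, so
    // the first vector passes through untouched.
    void RemoveOverlap(std::vector<int64_t>* vector, int overlap,
                       int* overlapping_values) {
      if (overlap <= 0) return;
      if (*overlapping_values > 0) {
        if (vector->size() < static_cast<size_t>(*overlapping_values)) {
          vector->clear();
        } else {
          vector->erase(vector->begin(),
                        vector->begin() + *overlapping_values);
        }
      }
      *overlapping_values = overlap;
    }
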
PiperOrigin-RevId: 499635143 --- .../tensor_to_vector_int_calculator.cc | 20 +++++++ ...sor_to_vector_int_calculator_options.proto | 4 ++ .../tensor_to_vector_int_calculator_test.cc | 53 ++++++++++++++++++- 3 files changed, 76 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc index 2f4ff28cf..f92ddf08d 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc @@ -37,8 +37,10 @@ class TensorToVectorIntCalculator : public CalculatorBase { private: void TokenizeVector(std::vector* vector) const; + void RemoveOverlapVector(std::vector* vector); TensorToVectorIntCalculatorOptions options_; + int32_t overlapping_values_; }; REGISTER_CALCULATOR(TensorToVectorIntCalculator); @@ -66,6 +68,7 @@ absl::Status TensorToVectorIntCalculator::GetContract(CalculatorContract* cc) { absl::Status TensorToVectorIntCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); + overlapping_values_ = 0; // Inform mediapipe that this calculator produces an output at time t for // each input received at time t (i.e. this calculator does not buffer @@ -106,6 +109,7 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(&instance_output); + RemoveOverlapVector(&instance_output); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } else { @@ -128,12 +132,28 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(output.get()); + RemoveOverlapVector(output.get()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } return absl::OkStatus(); } +void TensorToVectorIntCalculator::RemoveOverlapVector( + std::vector* vector) { + if (options_.overlap() <= 0) { + return; + } + if (overlapping_values_ > 0) { + if (vector->size() < overlapping_values_) { + vector->clear(); + } else { + vector->erase(vector->begin(), vector->begin() + overlapping_values_); + } + } + overlapping_values_ = options_.overlap(); +} + void TensorToVectorIntCalculator::TokenizeVector( std::vector* vector) const { if (!options_.tensor_is_token()) { diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto index 9da3298b9..76b9be952 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto @@ -36,4 +36,8 @@ message TensorToVectorIntCalculatorOptions { optional bool tensor_is_token = 3 [default = false]; // Threshold for the token generation. optional float token_threshold = 4 [default = 0.5]; + + // Values which overlap between timely following vectors. They are removed + // from the output to reduce redundancy. 
+  optional int32 overlap = 5 [default = 0];
 }

diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc
index 60c0d47ec..406c2c1a7 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc
@@ -28,7 +28,8 @@ namespace tf = ::tensorflow;
 class TensorToVectorIntCalculatorTest : public ::testing::Test {
  protected:
   void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd,
-                   const bool tensor_is_token = false) {
+                   const bool tensor_is_token = false,
+                   const int32_t overlap = 0) {
     CalculatorGraphConfig::Node config;
     config.set_calculator("TensorToVectorIntCalculator");
     config.add_input_stream("input_tensor");
@@ -38,6 +39,7 @@ class TensorToVectorIntCalculatorTest : public ::testing::Test {
     options->set_tensor_is_2d(tensor_is_2d);
     options->set_flatten_nd(flatten_nd);
     options->set_tensor_is_token(tensor_is_token);
+    options->set_overlap(overlap);
     runner_ = absl::make_unique<CalculatorRunner>(config);
   }

@@ -188,5 +190,54 @@ TEST_F(TensorToVectorIntCalculatorTest, FlattenShouldTakeAllDimensions) {
   }
 }

+TEST_F(TensorToVectorIntCalculatorTest, Overlap) {
+  SetUpRunner(false, false, false, 2);
+  for (int time = 0; time < 3; ++time) {
+    const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
+    auto tensor = absl::make_unique<tf::Tensor>(tf::DT_INT64, tensor_shape);
+    auto tensor_vec = tensor->vec<tf::int64>();
+    for (int i = 0; i < 5; ++i) {
+      // 2^i can be represented exactly in floating point numbers if 'i' is
+      // small.
+      tensor_vec(i) = static_cast<tf::int64>(time + (1 << i));
+    }
+
+    runner_->MutableInputs()->Index(0).packets.push_back(
+        Adopt(tensor.release()).At(Timestamp(time)));
+  }
+
+  ASSERT_TRUE(runner_->Run().ok());
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Index(0).packets;
+  EXPECT_EQ(3, output_packets.size());
+
+  {
+    // First vector in full.
+    int time = 0;
+    EXPECT_EQ(time, output_packets[time].Timestamp().Value());
+    const std::vector<int64>& output_vector =
+        output_packets[time].Get<std::vector<int64>>();
+
+    EXPECT_EQ(5, output_vector.size());
+    for (int i = 0; i < 5; ++i) {
+      const int64 expected = static_cast<int64>(time + (1 << i));
+      EXPECT_EQ(expected, output_vector[i]);
+    }
+  }
+
+  // All following vectors have the overlap removed.
+  for (int time = 1; time < 3; ++time) {
+    EXPECT_EQ(time, output_packets[time].Timestamp().Value());
+    const std::vector<int64>& output_vector =
+        output_packets[time].Get<std::vector<int64>>();
+
+    EXPECT_EQ(3, output_vector.size());
+    for (int i = 0; i < 3; ++i) {
+      const int64 expected = static_cast<int64>(time + (1 << (i + 2)));
+      EXPECT_EQ(expected, output_vector[i]);
+    }
+  }
+}
+
 }  // namespace
 }  // namespace mediapipe

From 463cbb60eea6af436bbec6d13fceae0f65cdbe64 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 5 Jan 2023 07:55:57 -0800
Subject: [PATCH 249/346] Fix RGBA vs RGB selection when creating GLTexture.
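The bug: the ternary mapped a 4-channel input to GL_RGB and everything else to
GL_RGBA, i.e. exactly backwards. A hedged stand-alone sketch of the corrected
mapping follows (the helper name is invented for illustration; the hex
constants are the standard OpenGL enum values normally supplied by the GL
headers):

    #include <cassert>
    #include <cstdint>

    // Standard OpenGL pixel-format enums (values from the GL headers).
    constexpr uint32_t kGlRgb = 0x1907;   // GL_RGB
    constexpr uint32_t kGlRgba = 0x1908;  // GL_RGBA

    // A 4-channel (RGBA) input must be described as GL_RGBA and a 3-channel
    // input as GL_RGB -- the direction the one-line fix below restores.
    constexpr uint32_t GlFormatForChannels(int num_channels) {
      return num_channels == 4 ? kGlRgba : kGlRgb;
    }

    int main() {
      assert(GlFormatForChannels(4) == kGlRgba);  // keeps the alpha channel
      assert(GlFormatForChannels(3) == kGlRgb);
      return 0;
    }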
PiperOrigin-RevId: 499877590 --- .../calculators/tensor/image_to_tensor_converter_gl_buffer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index a551e7f8d..eb1726aac 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -285,7 +285,7 @@ class GlProcessor : public ImageToTensorConverter { auto source_texture = gl_helper_.CreateSourceTexture(input); tflite::gpu::gl::GlTexture input_texture( GL_TEXTURE_2D, source_texture.name(), - input_num_channels == 4 ? GL_RGB : GL_RGBA, + input_num_channels == 4 ? GL_RGBA : GL_RGB, source_texture.width() * source_texture.height() * input_num_channels * sizeof(uint8_t), /*layer=*/0, From 35293d88bcb35b87162fbbb40b76226677f98d3f Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 5 Jan 2023 08:54:25 -0800 Subject: [PATCH 250/346] Tensor: move into tensor sub-directory. PiperOrigin-RevId: 499896489 --- mediapipe/framework/formats/BUILD | 2 +- mediapipe/framework/formats/tensor.h | 2 +- mediapipe/framework/formats/tensor/BUILD | 24 +++++++++++++++++++ .../{tensor_internal.h => tensor/internal.h} | 0 .../framework/formats/tensor_ahwb_gpu_test.cc | 2 +- 5 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 mediapipe/framework/formats/tensor/BUILD rename mediapipe/framework/formats/{tensor_internal.h => tensor/internal.h} (100%) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index cce7e5bd0..371f23ed1 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -430,7 +430,7 @@ cc_library( ], hdrs = [ "tensor.h", - "tensor_internal.h", + "//mediapipe/framework/formats/tensor:internal.h", ], copts = select({ "//mediapipe:apple": [ diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 0f19bb5ee..4a952ae09 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -26,7 +26,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" -#include "mediapipe/framework/formats/tensor_internal.h" +#include "mediapipe/framework/formats/tensor/internal.h" #include "mediapipe/framework/port.h" #if MEDIAPIPE_METAL_ENABLED diff --git a/mediapipe/framework/formats/tensor/BUILD b/mediapipe/framework/formats/tensor/BUILD new file mode 100644 index 000000000..c634b0dda --- /dev/null +++ b/mediapipe/framework/formats/tensor/BUILD @@ -0,0 +1,24 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
+
+licenses(["notice"])
+
+exports_files([
+    "internal.h",
+])
diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor/internal.h
similarity index 100%
rename from mediapipe/framework/formats/tensor_internal.h
rename to mediapipe/framework/formats/tensor/internal.h
diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
index ff78d1f88..b06bd3ef2 100644
--- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -7,7 +7,7 @@
 #include
 #include "mediapipe/framework/formats/tensor.h"
-#include "mediapipe/framework/formats/tensor_data_types.h"
+#include "mediapipe/framework/formats/tensor/views/data_types.h"
 #include "mediapipe/gpu/gpu_test_base.h"
 #include "mediapipe/gpu/shader_util.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"

From 81a46bb31a5da15a0ddd7123b92499c6ca14dc86 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Thu, 5 Jan 2023 09:12:06 -0800
Subject: [PATCH 251/346] Internal change

PiperOrigin-RevId: 499902323
---
 mediapipe/web/graph_runner/graph_runner.ts | 73 ++++++++++++--------
 1 file changed, 46 insertions(+), 27 deletions(-)

diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts
index ef866bc91..644d74918 100644
--- a/mediapipe/web/graph_runner/graph_runner.ts
+++ b/mediapipe/web/graph_runner/graph_runner.ts
@@ -73,10 +73,11 @@ export declare interface WasmModule {

   // Wasm Module output listener entrypoints. Also built as part of
   // gl_graph_runner_internal_multi_input.
-  simpleListeners?: {[outputStreamName: string]: (data: unknown) => void};
+  simpleListeners?:
+      {[outputStreamName: string]: (data: unknown, timestamp: number) => void};
   vectorListeners?: {
     [outputStreamName: string]: (
-        data: unknown, index: number, length: number) => void
+        data: unknown, index: number, length: number, timestamp: number) => void
   };
   _attachBoolListener: (streamNamePtr: number) => void;
   _attachBoolVectorListener: (streamNamePtr: number) => void;
@@ -418,10 +419,12 @@ export class GraphRunner {
    * Ensures existence of the simple listeners table and registers the callback.
    * Intended for internal usage.
    */
-  setListener<T>(outputStreamName: string, callbackFcn: (data: T) => void) {
+  setListener<T>(
+      outputStreamName: string,
+      callbackFcn: (data: T, timestamp: number) => void) {
     this.wasmModule.simpleListeners = this.wasmModule.simpleListeners || {};
     this.wasmModule.simpleListeners[outputStreamName] =
-        callbackFcn as (data: unknown) => void;
+        callbackFcn as (data: unknown, timestamp: number) => void;
   }

   /**
    * Ensures existence of the vector listeners table and registers the callback.
    * Intended for internal usage.
    */
   setVectorListener<T>(
-      outputStreamName: string, callbackFcn: (data: T[]) => void) {
+      outputStreamName: string,
+      callbackFcn: (data: T[], timestamp: number) => void) {
     let buffer: T[] = [];
     this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {};
     this.wasmModule.vectorListeners[outputStreamName] =
-        (data: unknown, index: number, length: number) => {
+        (data: unknown, index: number, length: number, timestamp: number) => {
           // The Wasm listener gets invoked once for each element. Once we
           // receive all elements, we invoke the registered callback with the
           // full array.
@@ -442,7 +446,7 @@ export class GraphRunner { // Invoke the user callback directly, as the Wasm layer may clean up // the underlying data elements once we leave the scope of the // listener. - callbackFcn(buffer); + callbackFcn(buffer, timestamp); buffer = []; } }; @@ -740,7 +744,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachBoolListener( - outputStreamName: string, callbackFcn: (data: boolean) => void): void { + outputStreamName: string, + callbackFcn: (data: boolean, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -760,7 +765,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachBoolVectorListener( - outputStreamName: string, callbackFcn: (data: boolean[]) => void): void { + outputStreamName: string, + callbackFcn: (data: boolean[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -780,7 +786,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachIntListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -800,7 +807,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachIntVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -820,7 +828,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachDoubleListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -840,7 +849,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachDoubleVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -860,7 +870,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachFloatListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -880,7 +891,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. 
*/ attachFloatVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -900,7 +912,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachStringListener( - outputStreamName: string, callbackFcn: (data: string) => void): void { + outputStreamName: string, + callbackFcn: (data: string, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -920,7 +933,8 @@ export class GraphRunner { * should not perform overly complicated (or any async) behavior. */ attachStringVectorListener( - outputStreamName: string, callbackFcn: (data: string[]) => void): void { + outputStreamName: string, + callbackFcn: (data: string[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -950,7 +964,8 @@ export class GraphRunner { * with it). */ attachProtoListener( - outputStreamName: string, callbackFcn: (data: Uint8Array) => void, + outputStreamName: string, + callbackFcn: (data: Uint8Array, timestamp: number) => void, makeDeepCopy?: boolean): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -984,7 +999,8 @@ export class GraphRunner { * with it). */ attachProtoVectorListener( - outputStreamName: string, callbackFcn: (data: Uint8Array[]) => void, + outputStreamName: string, + callbackFcn: (data: Uint8Array[], timestamp: number) => void, makeDeepCopy?: boolean): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -1017,8 +1033,10 @@ export class GraphRunner { * up automatically by JS garbage collection whenever the user is finished * with it). */ - attachAudioListener(outputStreamName: string, - callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void { + attachAudioListener( + outputStreamName: string, + callbackFcn: (data: Float32Array, timestamp: number) => void, + makeDeepCopy?: boolean): void { if (!this.wasmModule._attachAudioListener) { console.warn( 'Attempting to use attachAudioListener without support for ' + @@ -1027,12 +1045,13 @@ export class GraphRunner { // Set up our TS listener to receive any packets for this stream, and // additionally reformat our Uint8Array into a Float32Array for the user. - this.setListener(outputStreamName, (data: Uint8Array) => { - // Should be very fast - const floatArray = - new Float32Array(data.buffer, data.byteOffset, data.length / 4); - callbackFcn(floatArray); - }); + this.setListener( + outputStreamName, (data: Uint8Array, timestamp: number) => { + // Should be very fast + const floatArray = + new Float32Array(data.buffer, data.byteOffset, data.length / 4); + callbackFcn(floatArray, timestamp); + }); // Tell our graph to listen for string packets on this stream. 
this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { From 667fd81ddc12be213c0091c73f4c71fe0e4e35b2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 5 Jan 2023 11:40:59 -0800 Subject: [PATCH 252/346] Internal change PiperOrigin-RevId: 499956657 --- .../audio_classifier/audio_classifier_test.ts | 33 ++++++++++--------- .../audio_embedder/audio_embedder_test.ts | 11 ++++--- .../text_classifier/text_classifier_test.ts | 18 +++++----- .../text/text_embedder/text_embedder_test.ts | 7 ++-- .../gesture_recognizer_test.ts | 32 ++++++++++-------- .../hand_landmarker/hand_landmarker_test.ts | 14 ++++---- .../image_classifier/image_classifier_test.ts | 6 ++-- .../image_embedder/image_embedder_test.ts | 5 +-- .../object_detector/object_detector_test.ts | 5 +-- 9 files changed, 75 insertions(+), 56 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index 2089f184f..b7bb158de 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -34,7 +34,8 @@ class AudioClassifierFake extends AudioClassifier implements attachListenerSpies: jasmine.Spy[] = []; graph: CalculatorGraphConfig|undefined; - private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + private protoVectorListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; private resultProtoVector: ClassificationResult[] = []; constructor() { @@ -59,8 +60,10 @@ class AudioClassifierFake extends AudioClassifier implements }); spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => { if (!this.protoVectorListener) return; - this.protoVectorListener(this.resultProtoVector.map( - classificationResult => classificationResult.serializeBinary())); + this.protoVectorListener( + this.resultProtoVector.map( + classificationResult => classificationResult.serializeBinary()), + 1337); }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -138,12 +141,12 @@ describe('AudioClassifier', () => { classifcations.setHeadIndex(1); classifcations.setHeadName('headName'); let classificationList = new ClassificationList(); - let clasification = new Classification(); - clasification.setIndex(1); - clasification.setScore(0.2); - clasification.setDisplayName('displayName'); - clasification.setLabel('categoryName'); - classificationList.addClassification(clasification); + let classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); resultProtoVector.push(classificationResult); @@ -152,10 +155,10 @@ describe('AudioClassifier', () => { classificationResult.setTimestampMs(1); classifcations = new Classifications(); classificationList = new ClassificationList(); - clasification = new Classification(); - clasification.setIndex(2); - clasification.setScore(0.3); - classificationList.addClassification(clasification); + classification = new Classification(); + classification.setIndex(2); + classification.setScore(0.3); + classificationList.addClassification(classification); 
classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); resultProtoVector.push(classificationResult); @@ -191,8 +194,8 @@ describe('AudioClassifier', () => { const classificationResult = new ClassificationResult(); const classifcations = new Classifications(); const classificationList = new ClassificationList(); - const clasification = new Classification(); - classificationList.addClassification(clasification); + const classification = new Classification(); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index dde61a6e9..a8a2b232b 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -34,8 +34,10 @@ class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { attachListenerSpies: jasmine.Spy[] = []; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProto: Uint8Array) => void)|undefined; - protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; + protoVectorListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -163,7 +165,7 @@ describe('AudioEmbedder', () => { audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(audioEmbedder); // Pass the test data to our listener - audioEmbedder.protoListener!(resultProto.serializeBinary()); + audioEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); // Invoke the audio embedder @@ -175,7 +177,8 @@ describe('AudioEmbedder', () => { audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(audioEmbedder); // Pass the test data to our listener - audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]); + audioEmbedder.protoVectorListener! 
+ ([resultProto.serializeBinary()], 1337); }); // Invoke the audio embedder diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts index 5578362cb..d9eb14865 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -32,7 +32,8 @@ class TextClassifierFake extends TextClassifier implements MediapipeTasksFake { attachListenerSpies: jasmine.Spy[] = []; graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -118,19 +119,20 @@ describe('TextClassifier', () => { classifcations.setHeadIndex(1); classifcations.setHeadName('headName'); const classificationList = new ClassificationList(); - const clasification = new Classification(); - clasification.setIndex(1); - clasification.setScore(0.2); - clasification.setDisplayName('displayName'); - clasification.setLabel('categoryName'); - classificationList.addClassification(clasification); + const classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); // Pass the test data to our listener textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(textClassifier); - textClassifier.protoListener!(classificationResult.serializeBinary()); + textClassifier.protoListener! 
+ (classificationResult.serializeBinary(), 1337); }); // Invoke the text classifier diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts index 2804e4deb..e26b85bf4 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -31,7 +31,8 @@ class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; attachListenerSpies: jasmine.Spy[] = []; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + protoListener: + ((binaryProtos: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -120,7 +121,7 @@ describe('TextEmbedder', () => { // Pass the test data to our listener textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(textEmbedder); - textEmbedder.protoListener!(resultProto.serializeBinary()); + textEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); // Invoke the text embedder @@ -149,7 +150,7 @@ describe('TextEmbedder', () => { // Pass the test data to our listener textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(textEmbedder); - textEmbedder.protoListener!(resultProto.serializeBinary()); + textEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); // Invoke the text embedder diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index ee51fd32a..3611c3a7d 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -26,7 +26,7 @@ import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer' // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern -type ProtoListener = ((binaryProtos: Uint8Array[]) => void); +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); function createHandednesses(): Uint8Array[] { const handsProto = new ClassificationList(); @@ -254,11 +254,13 @@ describe('GestureRecognizer', () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(gestureRecognizer); - gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); gestureRecognizer.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - gestureRecognizer.listeners.get('handedness')!(createHandednesses()); - gestureRecognizer.listeners.get('hand_gestures')!(createGestures()); + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337); }); // Invoke the gesture recognizer @@ -290,11 +292,13 @@ describe('GestureRecognizer', () => { it('clears results between invoations', async () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { - gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('hand_landmarks')! 
+ (createLandmarks(), 1337); gestureRecognizer.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - gestureRecognizer.listeners.get('handedness')!(createHandednesses()); - gestureRecognizer.listeners.get('hand_gestures')!(createGestures()); + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337); }); // Invoke the gesture recognizer twice @@ -310,11 +314,13 @@ describe('GestureRecognizer', () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(gestureRecognizer); - gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); gestureRecognizer.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - gestureRecognizer.listeners.get('handedness')!(createHandednesses()); - gestureRecognizer.listeners.get('hand_gestures')!([]); + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!([], 1337); }); // Invoke the gesture recognizer diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 76e77b4bf..1a813c6f7 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -27,7 +27,7 @@ import {HandLandmarkerOptions} from './hand_landmarker_options'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern -type ProtoListener = ((binaryProtos: Uint8Array[]) => void); +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); function createHandednesses(): Uint8Array[] { const handsProto = new ClassificationList(); @@ -206,10 +206,10 @@ describe('HandLandmarker', () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(handLandmarker); - handLandmarker.listeners.get('hand_landmarks')!(createLandmarks()); + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); handLandmarker.listeners.get('world_hand_landmarks')! - (createWorldLandmarks()); - handLandmarker.listeners.get('handedness')!(createHandednesses()); + (createWorldLandmarks(), 1337); + handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); }); // Invoke the hand landmarker @@ -235,10 +235,10 @@ describe('HandLandmarker', () => { it('clears results between invoations', async () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { - handLandmarker.listeners.get('hand_landmarks')!(createLandmarks()); + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); handLandmarker.listeners.get('world_hand_landmarks')! 
- (createWorldLandmarks()); - handLandmarker.listeners.get('handedness')!(createHandednesses()); + (createWorldLandmarks(), 1337); + handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); }); // Invoke the hand landmarker twice diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts index da4a01d02..60595310e 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -35,7 +35,8 @@ class ImageClassifierFake extends ImageClassifier implements graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -128,7 +129,8 @@ describe('ImageClassifier', () => { // Pass the test data to our listener imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(imageClassifier); - imageClassifier.protoListener!(classificationResult.serializeBinary()); + imageClassifier.protoListener! + (classificationResult.serializeBinary(), 1337); }); // Invoke the image classifier diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index b63bb374c..01ec751e3 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -31,7 +31,8 @@ class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; attachListenerSpies: jasmine.Spy[] = []; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + protoListener: + ((binaryProtos: Uint8Array, timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -125,7 +126,7 @@ describe('ImageEmbedder', () => { // Pass the test data to our listener imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(imageEmbedder); - imageEmbedder.protoListener!(resultProto.serializeBinary()); + imageEmbedder.protoListener!(resultProto.serializeBinary(), 1337); }); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 43b7035d5..5bfb74ab6 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -35,7 +35,8 @@ class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + protoListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -200,7 +201,7 @@ describe('ObjectDetector', () => { // Pass the test data to our listener objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(objectDetector); - objectDetector.protoListener!(detectionProtos); + objectDetector.protoListener!(detectionProtos, 1337); }); // Invoke the object detector From 33df6c042fc3d78f525a6f7c86b10d67d091ddf1 Mon Sep 17 
00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:07:11 +0530 Subject: [PATCH 253/346] Added iOS result containers for classification tasks --- .../tasks/ios/components/containers/BUILD | 32 +++++ .../containers/sources/MPPCategory.h | 68 ++++++++++ .../containers/sources/MPPCategory.m | 33 +++++ .../sources/MPPClassificationResult.h | 116 ++++++++++++++++++ .../sources/MPPClassificationResult.m | 51 ++++++++ 5 files changed, 300 insertions(+) create mode 100644 mediapipe/tasks/ios/components/containers/BUILD create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPCategory.h create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPCategory.m create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD new file mode 100644 index 000000000..9d82fc55a --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -0,0 +1,32 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategory", + srcs = ["sources/MPPCategory.m"], + hdrs = ["sources/MPPCategory.h"], +) + +objc_library( + name = "MPPClassificationResult", + srcs = ["sources/MPPClassificationResult.m"], + hdrs = ["sources/MPPClassificationResult.h"], + deps = [ + ":MPPCategory", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h new file mode 100644 index 000000000..648725d95 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -0,0 +1,68 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Category is a util class that contains a label, its display name, a float value as score, and the + * index of the label in the corresponding label file. Typically it's used as the result of + * classification tasks. + **/ +NS_SWIFT_NAME(ClassificationCategory) +@interface MPPCategory : NSObject + +/** + * The index of the label in the corresponding label file. Set to -1 if the index is + * not set. 
+ **/ +@property(nonatomic, readonly) NSInteger index; + +/** Confidence score for this class . **/ +@property(nonatomic, readonly) float score; + +/** The label of this category object. **/ +@property(nonatomic, readonly, nullable) NSString *categoryName; + +/** + * The display name of the label, which may be translated for different locales. For example, a + * label, "apple", may be translated into Spanish for display purpose, so that the display name is + * "manzana". + **/ +@property(nonatomic, readonly, nullable) NSString *displayName; + +/** + * Initializes a new `MPPCategory` with the given index, score, category name and display name. + * + * @param index The index of the label in the corresponding label file. + * @param score The probability score of this label category. + * @param categoryName The label of this category object. + * @param displayName The display name of the label. + * + * @return An instance of `MPPCategory` initialized with the given index, score, category name and + * display name. + **/ +- (instancetype)initWithIndex:(NSInteger)index + score:(float)score + categoryName:(nullable NSString *)categoryName + displayName:(nullable NSString *)displayName NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m new file mode 100644 index 000000000..824fae65e --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -0,0 +1,33 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +@implementation MPPCategory + +- (instancetype)initWithIndex:(NSInteger)index + score:(float)score + categoryName:(nullable NSString *)categoryName + displayName:(nullable NSString *)displayName { + self = [super init]; + if (self) { + _index = index; + _score = score; + _categoryName = categoryName; + _displayName = displayName; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h new file mode 100644 index 000000000..9c8b9bd2e --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -0,0 +1,116 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Represents the list of classification for a given classifier head. Typically used as a result + * for classification tasks. + **/ +NS_SWIFT_NAME(Classifications) +@interface MPPClassifications : NSObject + +/** + * The index of the classifier head these entries refer to. This is useful for multi-head models. + **/ +@property(nonatomic, readonly) NSInteger headIndex; + +/** The optional name of the classifier head, which is the corresponding tensor metadata name. **/ +@property(nonatomic, readonly, nullable) NSString *headName; + +/** An array of `MPPCategory` objects containing the predicted categories. **/ +@property(nonatomic, readonly) NSArray *categories; + +/** + * Initializes a new `MPPClassifications` object with the given head index and array of categories. + * Head name is initialized to `nil`. + * + * @param headIndex The index of the classifier head. + * @param categories An array of `MPPCategory` objects containing the predicted categories. + * + * @return An instance of `MPPClassifications` initialized with the given head index and + * array of categories. + **/ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray *)categories; + +/** + * Initializes a new `MPPClassifications` with the given head index, head name and array of + * categories. + * + * @param headIndex The index of the classifier head. + * @param headName The name of the classifier head, which is the corresponding tensor metadata + * name. + * @param categories An array of `MPPCategory` objects containing the predicted categories. + * + * @return An object of `MPPClassifications` initialized with the given head index, head name and + * array of categories. + **/ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray *)categories NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +/** + * Represents the classification results of a model. Typically used as a result for classification + * tasks. + **/ +NS_SWIFT_NAME(ClassificationResult) +@interface MPPClassificationResult : NSObject + +/** + * An Array of `MPPClassifications` objects containing the predicted categories for each head of + * the model. + **/ +@property(nonatomic, readonly) NSArray *classifications; + +/** + * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. If it is set to the value -1, it signifies the absence of a timestamp. This is + * only used for classification on time series (e.g. audio classification). In these use cases, the + * amount of data to process might exceed the maximum size that the model can process: to solve + * this, the input data is split into multiple chunks starting at different timestamps. + **/ +@property(nonatomic, readonly) NSInteger timestampMs; + +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications and time + * stamp (in milliseconds). + * + * @param classifications An Array of `MPPClassifications` objects containing the predicted + * categories for each head of the model. + * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. 
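+ * Pass -1 to signify the absence of a timestamp, mirroring the convention
+ * documented on the `timestampMs` property above.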
+ * + * @return An instance of `MPPClassificationResult` initialized with the given array of + * classifications and timestampMs. + **/ +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m new file mode 100644 index 000000000..6d42d22ca --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -0,0 +1,51 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +@implementation MPPClassifications + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray *)categories { + self = [super init]; + if (self) { + _headIndex = headIndex; + _headName = headName; + _categories = categories; + } + return self; +} + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray *)categories { + return [self initWithHeadIndex:headIndex headName:nil categories:categories]; +} + +@end + +@implementation MPPClassificationResult + +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs { + self = [super init]; + if (self) { + _classifications = classifications; + _timestampMs = timestampMs; + } + + return self; +} + +@end From 89aad67a877424d1715b0a0dbfb386cfd06e8c2a Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:07:50 +0530 Subject: [PATCH 254/346] Added iOS helpers for classification result containers --- .../ios/components/containers/utils/BUILD | 40 ++++++++++++ .../utils/sources/MPPCategory+Helpers.h | 26 ++++++++ .../utils/sources/MPPCategory+Helpers.mm | 43 +++++++++++++ .../sources/MPPClassificationResult+Helpers.h | 35 ++++++++++ .../MPPClassificationResult+Helpers.mm | 64 +++++++++++++++++++ 5 files changed, 208 insertions(+) create mode 100644 mediapipe/tasks/ios/components/containers/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD new file mode 100644 index 000000000..e4c76ac4b --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -0,0 +1,40 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategoryHelpers", + srcs = ["sources/MPPCategory+Helpers.mm"], + hdrs = ["sources/MPPCategory+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPCategory", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPClassificationResultHelpers", + srcs = ["sources/MPPClassificationResult+Helpers.mm"], + hdrs = ["sources/MPPClassificationResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ":MPPCategoryHelpers", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h new file mode 100644 index 000000000..7580cfeeb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/classification.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm new file mode 100644 index 000000000..1c6c951d0 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -0,0 +1,43 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +namespace { +using ClassificationProto = ::mediapipe::Classification; +} + +@implementation MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { + NSString *categoryName; + NSString *displayName; + + if (clasificationProto.has_label()) { + categoryName = [NSString stringWithCppString:clasificationProto.label()]; + } + + if (clasificationProto.has_display_name()) { + displayName = [NSString stringWithCppString:clasificationProto.display_name()]; + } + + return [[MPPCategory alloc] initWithIndex:clasificationProto.index() + score:clasificationProto.score() + categoryName:categoryName + displayName:displayName]; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h new file mode 100644 index 000000000..fde436feb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h @@ -0,0 +1,35 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto; + +@end + +@interface MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const mediapipe::tasks::components::containers::proto::ClassificationResult &) + classificationResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm new file mode 100644 index 000000000..78bc0b6a3 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -0,0 +1,64 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" + +namespace { +using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications; +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +@implementation MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const ClassificationsProto &)classificationsProto { + NSMutableArray *categories = [NSMutableArray arrayWithCapacity:(NSUInteger)classificationsProto.classification_list().classification_size()]; + for (const auto &classification : classificationsProto.classification_list().classification()) { + [categories addObject:[MPPCategory categoryWithProto:classification]]; + } + + NSString *headName; + if (classificationsProto.has_head_name()) { + headName = [NSString stringWithCppString:classificationsProto.head_name()]; + } + + return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index() + headName:headName + categories:categories]; +} + +@end + +@implementation MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const ClassificationResultProto &)classificationResultProto { + NSMutableArray *classifications = [NSMutableArray arrayWithCapacity:(NSUInteger)classificationResultProto.classifications_size()]; + for (const auto &classificationsProto : classificationResultProto.classifications()) { + [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]]; + } + + NSInteger timestampMs; + if (classificationResultProto.has_timestamp_ms()) { + timestampMs = (NSInteger)classificationResultProto.timestamp_ms(); + } + + return [[MPPClassificationResult alloc] initWithClassifications:classifications timestampMs:timestampMs];; +} + +@end From 8f74a175d831bf09510afafc0e7ecbfb4f281a65 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:08:06 +0530 Subject: [PATCH 255/346] Removed MPPClassifierOptions and helpers --- .../tasks/ios/components/processors/BUILD | 23 ------- .../processors/sources/MPPClassifierOptions.h | 60 ------------------- .../processors/sources/MPPClassifierOptions.m | 40 ------------- .../ios/components/processors/utils/BUILD | 28 --------- .../sources/MPPClassifierOptions+Helpers.h | 26 -------- .../sources/MPPClassifierOptions+Helpers.mm | 43 ------------- 6 files changed, 220 deletions(-) delete mode 100644 mediapipe/tasks/ios/components/processors/BUILD delete mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h delete mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m delete mode 100644 mediapipe/tasks/ios/components/processors/utils/BUILD delete mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h delete mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD deleted file mode 100644 index 165145076..000000000 --- a/mediapipe/tasks/ios/components/processors/BUILD +++ /dev/null @@ 
-1,23 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(default_visibility = ["//mediapipe/tasks:internal"])
-
-licenses(["notice"])
-
-objc_library(
-    name = "MPPClassifierOptions",
-    srcs = ["sources/MPPClassifierOptions.m"],
-    hdrs = ["sources/MPPClassifierOptions.h"],
-)
diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h
deleted file mode 100644
index 13dca4030..000000000
--- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2022 The MediaPipe Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#import <Foundation/Foundation.h>
-
-NS_ASSUME_NONNULL_BEGIN
-
-/**
- * Holds settings for any single iOS MediaPipe classification task.
- */
-NS_SWIFT_NAME(ClassifierOptions)
-@interface MPPClassifierOptions : NSObject <NSCopying>
-
-/**
- * The locale to use for display names specified through the TFLite Model
- * Metadata, if any. Defaults to English.
- */
-@property(nonatomic, copy) NSString *displayNamesLocale;
-
-/**
- * The maximum number of top-scored classification results to return. If < 0,
- * all available results will be returned. If 0, an invalid argument error is
- * returned.
- */
-@property(nonatomic) NSInteger maxResults;
-
-/**
- * Score threshold to override the one provided in the model metadata (if any).
- * Results below this value are rejected.
- */
-@property(nonatomic) float scoreThreshold;
-
-/**
- * The allowlist of category names. If non-empty, detection results whose
- * category name is not in this set will be filtered out. Duplicate or unknown
- * category names are ignored. Mutually exclusive with categoryDenylist.
- */
-@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
-
-/**
- * The denylist of category names. If non-empty, detection results whose
- * category name is in this set will be filtered out. Duplicate or unknown
- * category names are ignored. Mutually exclusive with categoryAllowlist. 
- */
-@property(nonatomic, copy) NSArray<NSString *> *categoryDenylist;
-
-@end
-
-NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m
deleted file mode 100644
index 01f498184..000000000
--- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2022 The MediaPipe Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
-
-@implementation MPPClassifierOptions
-
-- (instancetype)init {
-  self = [super init];
-  if (self) {
-    _maxResults = -1;
-    _scoreThreshold = 0;
-  }
-  return self;
-}
-
-- (id)copyWithZone:(NSZone *)zone {
-  MPPClassifierOptions *classifierOptions = [[MPPClassifierOptions alloc] init];
-
-  classifierOptions.scoreThreshold = self.scoreThreshold;
-  classifierOptions.maxResults = self.maxResults;
-  classifierOptions.categoryDenylist = self.categoryDenylist;
-  classifierOptions.categoryAllowlist = self.categoryAllowlist;
-  classifierOptions.displayNamesLocale = self.displayNamesLocale;
-
-  return classifierOptions;
-}
-
-@end
diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD
deleted file mode 100644
index 5344c5fdf..000000000
--- a/mediapipe/tasks/ios/components/processors/utils/BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(default_visibility = ["//mediapipe/tasks:internal"])
-
-licenses(["notice"])
-
-objc_library(
-    name = "MPPClassifierOptionsHelpers",
-    srcs = ["sources/MPPClassifierOptions+Helpers.mm"],
-    hdrs = ["sources/MPPClassifierOptions+Helpers.h"],
-    deps = [
-        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
-        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
-        "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
-    ],
-)
diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h
deleted file mode 100644
index e156020df..000000000
--- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2022 The MediaPipe Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" - -#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" - -NS_ASSUME_NONNULL_BEGIN - -@interface MPPClassifierOptions (Helpers) -- (void)copyToProto: - (mediapipe::tasks::components::processors::proto::ClassifierOptions *)classifierOptionsProto; -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm deleted file mode 100644 index 24b54fd6a..000000000 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
-#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
-
-namespace {
-using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions;
-}
-
-@implementation MPPClassifierOptions (Helpers)
-
-- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
-  classifierOptionsProto->Clear();
-
-  if (self.displayNamesLocale) {
-    classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
-  }
-
-  classifierOptionsProto->set_max_results((int)self.maxResults);
-  classifierOptionsProto->set_score_threshold(self.scoreThreshold);
-
-  for (NSString *category in self.categoryAllowlist) {
-    classifierOptionsProto->add_category_allowlist(category.cppString);
-  }
-
-  for (NSString *category in self.categoryDenylist) {
-    classifierOptionsProto->add_category_denylist(category.cppString);
-  }
-}
-
-@end

From 4e38c7e623eaa7dc1219ab5291858e05703350c2 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 6 Jan 2023 16:15:32 +0530
Subject: [PATCH 256/346] Updated documentation for MPPCommon.h

---
 mediapipe/tasks/ios/common/sources/MPPCommon.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h
index 7ce791d12..09a61e20d 100644
--- a/mediapipe/tasks/ios/common/sources/MPPCommon.h
+++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h
@@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN
 
 /**
  * @enum MPPTasksErrorCode
- * This enum specifies error codes for Mediapipe Task Library.
+ * This enum specifies error codes for MediaPipe Task Library.
  * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library.
 */
 typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {

From f37689fc33de306cead655cadfd283430fc5003f Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 6 Jan 2023 16:15:53 +0530
Subject: [PATCH 257/346] Updated documentation for MPPCommonUtils.m

---
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h  | 2 +-
 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
index 5404a074d..69c28b916 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -18,7 +18,7 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-/** Error domain of Mediapipe Task related errors. */
+/** Error domain of MediaPipe Task related errors. */
 extern NSString *const MPPTasksErrorDomain;
 
 /** Helper utility for all tasks which encapsulates common functionality. */
diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
index 8234ac6d3..1a37f8465 100644
--- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
+++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
@@ -96,7 +96,7 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
   // The mapping to absl::Status::code() is done to generate a more specific error code than
   // MPPTasksErrorCodeError in cases when the payload can't be mapped to
   // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
-  // returned without modification by Mediapipe cc library methods. 
+ // returned without modification by MediaPipe cc library methods. if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { switch (status.code()) { case absl::StatusCode::kInternal: From 27ce2ec00f0fd5526c186c9b92e570a3acdca58c Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:22:11 +0530 Subject: [PATCH 258/346] Updated C++ types to camel case in MPPTaskInfo --- .../tasks/ios/core/sources/MPPTaskInfo.mm | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index 5f2290497..ae6ed2a70 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -13,6 +13,7 @@ // limitations under the License. #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" + #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" @@ -69,59 +70,59 @@ using ::mediapipe::InputStreamInfo; } - (CalculatorGraphConfig)generateGraphConfig { - CalculatorGraphConfig graph_config; + CalculatorGraphConfig graphConfig; - Node *task_subgraph_node = graph_config.add_node(); - task_subgraph_node->set_calculator(self.taskGraphName.cppString); - [self.taskOptions copyToProto:task_subgraph_node->mutable_options()]; + Node *taskSubgraphNode = graphConfig.add_node(); + taskSubgraphNode->set_calculator(self.taskGraphName.cppString); + [self.taskOptions copyToProto:taskSubgraphNode->mutable_options()]; for (NSString *outputStream in self.outputStreams) { - auto cpp_output_stream = std::string(outputStream.cppString); - task_subgraph_node->add_output_stream(cpp_output_stream); - graph_config.add_output_stream(cpp_output_stream); + auto cppOutputStream = std::string(outputStream.cppString); + taskSubgraphNode->add_output_stream(cppOutputStream); + graphConfig.add_output_stream(cppOutputStream); } if (!self.enableFlowLimiting) { for (NSString *inputStream in self.inputStreams) { - auto cpp_input_stream = inputStream.cppString; - task_subgraph_node->add_input_stream(cpp_input_stream); - graph_config.add_input_stream(cpp_input_stream); + auto cppInputStream = inputStream.cppString; + taskSubgraphNode->add_input_stream(cppInputStream); + graphConfig.add_input_stream(cppInputStream); } - return graph_config; + return graphConfig; } - Node *flow_limit_calculator_node = graph_config.add_node(); + Node *flowLimitCalculatorNode = graphConfig.add_node(); - flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + flowLimitCalculatorNode->set_calculator("FlowLimiterCalculator"); - InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); - input_stream_info->set_tag_index("FINISHED"); - input_stream_info->set_back_edge(true); + InputStreamInfo *inputStreamInfo = flowLimitCalculatorNode->add_input_stream_info(); + inputStreamInfo->set_tag_index("FINISHED"); + inputStreamInfo->set_back_edge(true); - FlowLimiterCalculatorOptions *flow_limit_calculator_options = - flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions *flowLimitCalculatorOptions = + flowLimitCalculatorNode->mutable_options()->MutableExtension( FlowLimiterCalculatorOptions::ext); - flow_limit_calculator_options->set_max_in_flight(1); - flow_limit_calculator_options->set_max_in_queue(1); + 
flowLimitCalculatorOptions->set_max_in_flight(1); + flowLimitCalculatorOptions->set_max_in_queue(1); for (NSString *inputStream in self.inputStreams) { - graph_config.add_input_stream(inputStream.cppString); + graphConfig.add_input_stream(inputStream.cppString); NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; - flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + flowLimitCalculatorNode->add_input_stream(strippedInputStream.cppString); NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; - task_subgraph_node->add_input_stream(taskInputStream.cppString); + taskSubgraphNode->add_input_stream(taskInputStream.cppString); NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; - flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + flowLimitCalculatorNode->add_output_stream(strippedTaskInputStream.cppString); } NSString *firstOutputStream = self.outputStreams[0]; - auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; - flow_limit_calculator_node->add_input_stream(finished_output_stream); + auto finishedOutputStream = "FINISHED:" + firstOutputStream.cppString; + flowLimitCalculatorNode->add_input_stream(finishedOutputStream); - return graph_config; + return graphConfig; } + (NSString *)stripTagIndex:(NSString *)tagIndexName { From 61d16b284b9d0b063a61f31ae9565532f6e69798 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:23:22 +0530 Subject: [PATCH 259/346] Updated comments in MPPTaskOptions.h --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index ee2f7d032..e10678348 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -25,7 +25,7 @@ NS_SWIFT_NAME(TaskOptions) @interface MPPTaskOptions : NSObject /** - * Base options for configuring the Mediapipe task. + * Base options for configuring the MediaPipe task. */ @property(nonatomic, copy) MPPBaseOptions *baseOptions; From 16f9831c3fece5d9907868ea60c7bb7fb0c01cc5 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:23:37 +0530 Subject: [PATCH 260/346] Updated formatting in MPPTaskOptions.m --- mediapipe/tasks/ios/core/sources/MPPTaskOptions.m | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index fe74517c3..ad11bbc6e 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -13,6 +13,7 @@ // limitations under the License. 
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" @implementation MPPTaskOptions From bc1b069edf818e9431697ceb040cc1c105984ef3 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:24:41 +0530 Subject: [PATCH 261/346] Updated property name in MPPTaskResult --- mediapipe/tasks/ios/core/sources/MPPTaskResult.h | 4 ++-- mediapipe/tasks/ios/core/sources/MPPTaskResult.m | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index d15d4f258..4ee7b2fc6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -26,11 +26,11 @@ NS_SWIFT_NAME(TaskResult) /** * Timestamp that is associated with the task result object. */ -@property(nonatomic, assign, readonly) long timestamp; +@property(nonatomic, assign, readonly) NSInteger timestampMs; - (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 7088eb246..6c08014ff 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -16,16 +16,16 @@ @implementation MPPTaskResult -- (instancetype)initWithTimestamp:(long)timestamp { +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { self = [super init]; if (self) { - _timestamp = timestamp; + _timestampMs = timestampMs; } return self; } - (id)copyWithZone:(NSZone *)zone { - return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp]; + return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; } @end From c6bae99a2fc120c4de58f352dc64b6dc0aff728b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:25:56 +0530 Subject: [PATCH 262/346] Updated formatting in MPPTextPacketCreator.mm --- mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm index ca86e7a0b..fb59b363d 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm @@ -13,6 +13,7 @@ // limitations under the License. 
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" + #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" namespace { From b6bcc35adef1ea3d27af6f35488a1608a4670be5 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:36:15 +0530 Subject: [PATCH 263/346] Added provision for packets callback in iOS task runner --- mediapipe/tasks/ios/core/BUILD | 12 +++-- .../tasks/ios/core/sources/MPPTaskRunner.h | 52 +++++++++++++++++-- .../tasks/ios/core/sources/MPPTaskRunner.mm | 12 ++++- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 434d20085..757e2d4cc 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -56,12 +56,12 @@ objc_library( deps = [ ":MPPTaskOptions", ":MPPTaskOptionsProtocol", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", ], ) @@ -88,8 +88,10 @@ objc_library( "-std=c++17", ], deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + "//mediapipe/tasks/cc/core:task_runner", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], ) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 2b9f2ecdb..a1b1dfad4 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -20,23 +20,65 @@ NS_ASSUME_NONNULL_BEGIN /** - * This class is used to create and call appropriate methods on the C++ Task Runner. - */ + * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, + * execute and terminate any MediaPipe task. + * + * An instance of the newly created C++ task runner will be stored until this class is destroyed. + * When methods are called for processing (performing inference), closing etc., on this class, + * internally the appropriate methods will be called on the C++ task runner instance to execute the + * appropriate actions. For each type of task, a subclass of this class must be defined to add any + * additional functionality. For eg:, vision tasks must create an `MPPVisionTaskRunner` and provide + * additional functionality. An instance of `MPPVisionTaskRunner` can in turn be used by the each + * vision task for creation and execution of the task. Please see the documentation for the C++ Task + * Runner for more details on how the taks runner operates. + **/ @interface MPPTaskRunner : NSObject /** - * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto. + * Initializes a new `MPPTaskRunner` with the MediaPipe calculator configuration proto and an + * optional C++ packets callback. + * + * You can pass `nullptr` for `packetsCallback` in case the mode of operation requested by the user + * is synchronous. 
+ *
+ * If the task is operating in asynchronous mode, any iOS MediaPipe task that uses the
+ * `MPPTaskRunner` must define a C++ callback function to obtain the results of inference
+ * asynchronously and deliver the results to the user. To accomplish this, the callback function
+ * should in turn invoke the block provided by the user in the task options supplied to create the
+ * task. Please see the documentation of the C++ Task Runner for more information on the
+ * synchronous and asynchronous modes of operation.
 *
 * @param graphConfig A MediaPipe task graph config proto.
+ * @param packetsCallback An optional C++ callback function that takes a list of output packets as
+ * the input argument. If provided, the callback must in turn call the block provided by the user
+ * in the appropriate task options.
 *
- * @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
- */
+ * @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional
+ * packetsCallback.
+ **/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
+                              packetsCallback:
+                                  (mediapipe::tasks::core::PacketsCallback)packetsCallback
                                         error:(NSError **)error NS_DESIGNATED_INITIALIZER;

+/**
+ * A synchronous method for processing batch data or offline streaming data. This method is
+ * designed for processing either batch data such as unrelated images and texts or offline
+ * streaming data such as the decoded frames from a video file or audio file. The call blocks the
+ * current thread until a failure status or a successful result is returned. If the input packets
+ * have no timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the
+ * timestamp is set in the input packets, the caller must ensure that the input packet timestamps
+ * are greater than the timestamps of the previous invocation. This method is thread-unsafe, and
+ * it is the caller's responsibility to synchronize access to this method across multiple threads
+ * and to ensure that the input packet timestamps are in order.
+ **/
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
    (const mediapipe::tasks::core::PacketMap &)packetMap;

+/**
+ * Shuts down the C++ task runner. After the runner is closed, any calls that send input data to
+ * the runner are illegal and will receive errors.
+ **/
- (absl::Status)close;

- (instancetype)init NS_UNAVAILABLE;

diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
index c5c307fd5..a77f206b2 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
@@ -13,11 +13,17 @@
 // limitations under the License. 
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" + #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "tensorflow/lite/core/api/op_resolver.h" + namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @@ -30,15 +36,17 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; @implementation MPPTaskRunner - (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + packetsCallback:(PacketsCallback)packetsCallback error:(NSError **)error { self = [super init]; if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig), + absl::make_unique(), + std::move(packetsCallback)); if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { return nil; } - _cppTaskRunner = std::move(taskRunnerResult.value()); } return self; From b91b485035545f3263cb88ade8444eb6fc32d407 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:36:28 +0530 Subject: [PATCH 264/346] Added MPPBaseOptions Helpers --- mediapipe/tasks/ios/core/utils/BUILD | 27 ++++++++++++ .../utils/sources/MPPBaseOptions+Helpers.h | 26 +++++++++++ .../utils/sources/MPPBaseOptions+Helpers.mm | 44 +++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 mediapipe/tasks/ios/core/utils/BUILD create mode 100644 mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/core/utils/BUILD b/mediapipe/tasks/ios/core/utils/BUILD new file mode 100644 index 000000000..1cfc75e6a --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/BUILD @@ -0,0 +1,27 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseOptionsHelpers", + srcs = ["sources/MPPBaseOptions+Helpers.mm"], + hdrs = ["sources/MPPBaseOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPBaseOptions", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h new file mode 100644 index 000000000..d52df2ae4 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPBaseOptions (Helpers) + +- (void)copyToProto:(mediapipe::tasks::core::proto::BaseOptions *)baseOptionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm new file mode 100644 index 000000000..3fd8fbda3 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +namespace { +using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; +} + +@implementation MPPBaseOptions (Helpers) + +- (void)copyToProto:(BaseOptionsProto *)baseOptionsProto { + baseOptionsProto->Clear(); + + if (self.modelAssetPath) { + baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); + } + + switch (self.delegate) { + case MPPDelegateCPU: { + baseOptionsProto->mutable_acceleration()->mutable_tflite(); + break; + } + case MPPDelegateGPU: { + // TODO: Provide an implementation for GPU Delegate. + [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."]; + } + default: + break; + } +} + +@end From 14e3de49ad09d9fc33cfe95ffb8038e473e132cf Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 6 Jan 2023 16:37:31 +0530 Subject: [PATCH 265/346] Added MPPTextTaskRunner --- mediapipe/tasks/ios/text/core/BUILD | 31 +++++++++++++ .../ios/text/core/sources/MPPTextTaskRunner.h | 43 +++++++++++++++++++ .../text/core/sources/MPPTextTaskRunner.mm | 29 +++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 mediapipe/tasks/ios/text/core/BUILD create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm diff --git a/mediapipe/tasks/ios/text/core/BUILD b/mediapipe/tasks/ios/text/core/BUILD new file mode 100644 index 000000000..bf88f5734 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+objc_library(
+    name = "MPPTextTaskRunner",
+    srcs = ["sources/MPPTextTaskRunner.mm"],
+    hdrs = ["sources/MPPTextTaskRunner.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+    deps = [
+        "//mediapipe/tasks/ios/core:MPPTaskRunner",
+    ],
+)
+
diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h
new file mode 100644
index 000000000..e3df3de9d
--- /dev/null
+++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h
@@ -0,0 +1,43 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * This class is used to create and call appropriate methods on the C++ Task Runner to initialize,
+ * execute and terminate any MediaPipe text task.
+ **/
+@interface MPPTextTaskRunner : MPPTaskRunner
+
+/**
+ * Initializes a new `MPPTextTaskRunner` with the MediaPipe calculator config proto.
+ *
+ * @param graphConfig A MediaPipe calculator config proto.
+ *
+ * @return An instance of `MPPTextTaskRunner` initialized to the given MediaPipe calculator config
+ * proto.
+ **/
+- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm
new file mode 100644
index 000000000..956448c17
--- /dev/null
+++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm
@@ -0,0 +1,29 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License. 
+
+#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
+
+namespace {
+using ::mediapipe::CalculatorGraphConfig;
+}  // namespace
+
+@implementation MPPTextTaskRunner
+
+- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error {
+  self = [super initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error];
+  return self;
+}
+
+@end

From 2cce88080e8d320a547a870e9bf3f2f9f86fa2e0 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 6 Jan 2023 15:18:12 -0800
Subject: [PATCH 266/346] Internal change

PiperOrigin-RevId: 500271109
---
 mediapipe/calculators/image/scale_image_utils.cc      |  6 ++++++
 mediapipe/calculators/image/scale_image_utils_test.cc | 10 ++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc
index 490d0336a..86a53ffc5 100644
--- a/mediapipe/calculators/image/scale_image_utils.cc
+++ b/mediapipe/calculators/image/scale_image_utils.cc
@@ -142,6 +142,9 @@ absl::Status FindOutputDimensions(int input_width,   //
                                      static_cast<double>(input_height));
     try_width = (try_width / 2) * 2;
     try_height = (try_height / 2) * 2;
+    // The output width/height should be greater than 0.
+    try_width = std::max(try_width, 1);
+    try_height = std::max(try_height, 1);
 
     if (target_height <= 0 || try_height <= target_height) {
       // The resulting height based on the target width and aspect ratio
@@ -160,6 +163,9 @@ absl::Status FindOutputDimensions(int input_width,   //
                                     static_cast<double>(input_width));
     try_width = (try_width / 2) * 2;
     try_height = (try_height / 2) * 2;
+    // The output width/height should be greater than 0.
+    try_width = std::max(try_width, 1);
+    try_height = std::max(try_height, 1);
 
     if (target_width <= 0 || try_width <= target_width) {
       // The resulting width based on the target width and aspect ratio
diff --git a/mediapipe/calculators/image/scale_image_utils_test.cc b/mediapipe/calculators/image/scale_image_utils_test.cc
index bda1fa4d6..b4810071c 100644
--- a/mediapipe/calculators/image/scale_image_utils_test.cc
+++ b/mediapipe/calculators/image/scale_image_utils_test.cc
@@ -124,6 +124,16 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) {
                                    &output_width, &output_height));
   EXPECT_EQ(151, output_width);
   EXPECT_EQ(101, output_height);
+  // Scale to height 1.
+  MP_ASSERT_OK(FindOutputDimensions(10000, 10, 100, 0, 0, true, 2,
+                                    &output_width, &output_height));
+  EXPECT_EQ(100, output_width);
+  EXPECT_EQ(1, output_height);
+  // Scale to width 1.
+  MP_ASSERT_OK(FindOutputDimensions(10, 10000, 0, 100, 0, true, 2,
+                                    &output_width, &output_height));
+  EXPECT_EQ(1, output_width);
+  EXPECT_EQ(100, output_height);
 }
 
 // Tests scaling without keeping the aspect ratio fixed.
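The clamping added in the patch above guards against degenerate aspect ratios: rounding a candidate dimension down to an even value truncates anything below 2 to 0, and a zero-sized output image would be invalid downstream. The short standalone C++ sketch below isolates just that round-then-clamp step and replays the 10000x10 case from the new unit test; the helper name is illustrative and not part of MediaPipe.

#include <algorithm>
#include <cassert>

// Illustrative helper, not MediaPipe API: round a candidate output dimension
// down to an even value, then clamp it to at least 1, mirroring the logic
// patched into FindOutputDimensions above.
static int RoundDownToEvenThenClamp(int candidate) {
  candidate = (candidate / 2) * 2;  // Round down to an even value.
  return std::max(candidate, 1);    // Guarantee a positive dimension.
}

int main() {
  // A 10000x10 input scaled to a target width of 100 produces a candidate
  // height of 100 * 10 / 10000 = 0 after truncation; the clamp turns that
  // into the 100x1 output the new test expects.
  assert(RoundDownToEvenThenClamp(0) == 1);
  assert(RoundDownToEvenThenClamp(101) == 100);  // Odd values round down to even.
  return 0;
}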
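On a related note, the synchronous versus asynchronous split documented in MPPTaskRunner.h earlier in this series hinges on whether a C++ packets callback is supplied at construction time. The sketch below only outlines that contract with stand-in types; the real Packet, PacketMap and PacketsCallback definitions live in the MediaPipe C++ task runner, so this should not be read as MediaPipe code.

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Stand-in types that mirror the shapes the C++ task runner works with; the
// real definitions live in mediapipe/tasks/cc/core/task_runner.h.
using Packet = std::string;
using PacketMap = std::map<std::string, Packet>;
using PacketsCallback = std::function<void(const PacketMap&)>;

// With a non-null callback the runner reports results asynchronously; with a
// null callback the caller uses the blocking process() path instead.
void RunOnce(const PacketMap& inputs, const PacketsCallback& callback) {
  PacketMap outputs = inputs;  // Pretend the graph produced output packets.
  if (callback) {
    callback(outputs);  // Asynchronous mode, as used by streaming tasks.
  }
}

int main() {
  RunOnce({{"text_in", "hello"}}, [](const PacketMap& outputs) {
    std::cout << "received " << outputs.size() << " output packet(s)\n";
  });
  return 0;
}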
From 9b34a105cfc3ca01a2a45afc011d613daaab7f26 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 18:15:34 -0800 Subject: [PATCH 267/346] Do not depend on Image methods in TaskRunner PiperOrigin-RevId: 500299571 --- .../tasks/web/audio/audio_classifier/BUILD | 1 + .../audio_classifier/audio_classifier.ts | 3 ++- .../tasks/web/audio/audio_embedder/BUILD | 1 + .../audio/audio_embedder/audio_embedder.ts | 3 ++- mediapipe/tasks/web/core/BUILD | 2 -- mediapipe/tasks/web/core/task_runner.ts | 21 +++++++------------ mediapipe/tasks/web/core/task_runner_test.ts | 20 +++++++----------- .../text/text_classifier/text_classifier.ts | 4 ++-- .../web/text/text_embedder/text_embedder.ts | 4 ++-- mediapipe/tasks/web/vision/core/BUILD | 2 ++ .../vision/core/vision_task_runner.test.ts | 4 ++-- .../web/vision/core/vision_task_runner.ts | 15 ++++++++++++- .../tasks/web/vision/gesture_recognizer/BUILD | 2 +- .../gesture_recognizer/gesture_recognizer.ts | 4 ++-- .../gesture_recognizer_test.ts | 4 ++-- .../tasks/web/vision/hand_landmarker/BUILD | 2 +- .../vision/hand_landmarker/hand_landmarker.ts | 4 ++-- .../hand_landmarker/hand_landmarker_test.ts | 5 +++-- .../image_classifier/image_classifier.ts | 4 ++-- .../vision/image_embedder/image_embedder.ts | 4 ++-- .../vision/object_detector/object_detector.ts | 4 ++-- 21 files changed, 61 insertions(+), 52 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 24ef31feb..a94b4931d 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -27,6 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 51573f50a..92fca93ad 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,6 +22,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -98,7 +99,7 @@ export class AudioClassifier extends AudioTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 0817776c5..68a7f7bd5 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -27,6 +27,7 @@ mediapipe_ts_library( 
"//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 6a4b8ce39..2e210f969 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -24,6 +24,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; +import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -100,7 +101,7 @@ export class AudioEmbedder extends AudioTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index c0d10d28b..371c75da0 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -22,7 +22,6 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], @@ -57,7 +56,6 @@ mediapipe_ts_library( deps = [ ":core", ":task_runner", - ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index ffb538b52..a3df7adf5 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -19,8 +19,7 @@ import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; -import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; -import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; +import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; import {WasmFileset} from './wasm_fileset'; @@ -29,10 +28,12 @@ import {WasmFileset} from './wasm_fileset'; const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing 
-const GraphRunnerImageLibType =
-    SupportModelResourcesGraphService(SupportImage(GraphRunner));
-/** An implementation of the GraphRunner that supports image operations */
-export class GraphRunnerImageLib extends GraphRunnerImageLibType {}
+const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
+/**
+ * An implementation of the GraphRunner that exposes the resource graph
+ * service.
+ */
+export class CachedGraphRunner extends CachedGraphRunnerType {}
 
 /**
  * Creates a new instance of a Mediapipe Task. Determines if SIMD is
@@ -64,7 +65,6 @@ export async function createTaskRunner(
 /** Base class for all MediaPipe Tasks. */
 export abstract class TaskRunner {
   protected abstract baseOptions: BaseOptionsProto;
-  protected graphRunner: GraphRunnerImageLib;
   private processingErrors: Error[] = [];
 
   /**
@@ -79,12 +79,7 @@ export abstract class TaskRunner {
   }
 
   /** @hideconstructor protected */
-  constructor(
-      wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null,
-      graphRunner?: GraphRunnerImageLib) {
-    this.graphRunner =
-        graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas);
-
+  constructor(protected readonly graphRunner: CachedGraphRunner) {
     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
     this.graphRunner.setAutoRenderToScreen(false);
diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts
index a55ac04d7..684beb70c 100644
--- a/mediapipe/tasks/web/core/task_runner_test.ts
+++ b/mediapipe/tasks/web/core/task_runner_test.ts
@@ -18,11 +18,10 @@ import 'jasmine';
 
 // Placeholder for internal dependency on encodeByteArray
 import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
 import {TaskRunner} from '../../../tasks/web/core/task_runner';
-import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
 import {ErrorListener} from '../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource URL builder
 
-import {GraphRunnerImageLib} from './task_runner';
+import {CachedGraphRunner} from './task_runner';
 import {TaskRunnerOptions} from './task_runner_options.d';
 
 class TaskRunnerFake extends TaskRunner {
@@ -32,18 +31,15 @@ class TaskRunnerFake extends TaskRunner {
   baseOptions = new BaseOptionsProto();
 
   static createFake(): TaskRunnerFake {
-    const wasmModule = createSpyWasmModule();
-    return new TaskRunnerFake(wasmModule);
+    return new TaskRunnerFake();
   }
 
-  constructor(wasmModuleFake: SpyWasmModule) {
-    super(
-        wasmModuleFake, /* glCanvas= */ null,
-        jasmine.createSpyObj([
-          'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
-          'registerModelResourcesGraphService', 'attachErrorListener'
-        ]));
-    const graphRunner = this.graphRunner as jasmine.SpyObj<GraphRunnerImageLib>;
+  constructor() {
+    super(jasmine.createSpyObj([
+      'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
+      'registerModelResourcesGraphService', 'attachErrorListener'
+    ]));
+    const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>;
     expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled();
     expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
     graphRunner.attachErrorListener.and.callFake(listener => {
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
index 981438625..6aef1b3e4 100644
--- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
+++ 
b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -96,7 +96,7 @@ export class TextClassifier extends TaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 7aa0aa6b9..db7986dec 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -23,7 +23,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -100,7 +100,7 @@ export class TextEmbedder extends TaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new CachedGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 03958a819..3574483df 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -20,7 +20,9 @@ mediapipe_ts_library( ":vision_task_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], ) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index d77cc4fed..f3f25070e 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -21,13 +21,13 @@ import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_u import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} 
from './vision_task_options'; -import {VisionTaskRunner} from './vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner'; class VisionTaskRunnerFake extends VisionTaskRunner { baseOptions = new BaseOptionsProto(); constructor() { - super(createSpyWasmModule(), /* glCanvas= */ null); + super(new VisionGraphRunner(createSpyWasmModule(), /* glCanvas= */ null)); } protected override process(): void {} diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 952990326..c3e0d3c7e 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -15,12 +15,25 @@ */ import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib'; +import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service'; import {VisionTaskOptions} from './vision_task_options'; +// tslint:disable-next-line:enforce-name-casing +const GraphRunnerVisionType = + SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class VisionGraphRunner extends GraphRunnerVisionType {} + /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { + /** @hideconstructor protected */ + constructor(protected override readonly graphRunner: VisionGraphRunner) { + super(graphRunner); + } + /** Configures the shared options of a vision task. 
 */
  override applyOptions(options: VisionTaskOptions): Promise<void> {
    if ('runningMode' in options) {
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
index aa2f9c366..5fdf9b43e 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
@@ -67,8 +67,8 @@ mediapipe_ts_library(
         "//mediapipe/framework/formats:classification_jspb_proto",
         "//mediapipe/framework/formats:landmark_jspb_proto",
         "//mediapipe/tasks/web/core",
-        "//mediapipe/tasks/web/core:task_runner",
         "//mediapipe/tasks/web/core:task_runner_test_utils",
+        "//mediapipe/tasks/web/vision/core:vision_task_runner",
     ],
 )
 
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index c77f2c67a..8d36ed89c 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -30,7 +30,7 @@ import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
-import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
 
@@ -131,7 +131,7 @@ export class GestureRecognizer extends
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(wasmModule, glCanvas);
+    super(new VisionGraphRunner(wasmModule, glCanvas));
 
     this.options = new GestureRecognizerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
index 3611c3a7d..3699033b2 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
@@ -18,8 +18,8 @@ import 'jasmine';
 import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
 import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
-import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
 import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 
 import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer';
 
@@ -98,7 +98,7 @@ class GestureRecognizerFake extends GestureRecognizer implements
     spyOn(this.graphRunner, 'addProtoToStream');
   }
 
-  getGraphRunner(): GraphRunnerImageLib {
+  getGraphRunner(): VisionGraphRunner {
     return this.graphRunner;
   }
 }
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD 
b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index d1f1e48f3..e7083a050 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -62,8 +62,8 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:vision_task_runner", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 24cf9a402..5db6d48f5 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -26,7 +26,7 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -119,7 +119,7 @@ export class HandLandmarker extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options = new HandLandmarkerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 1a813c6f7..bce0eac02 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -18,12 +18,13 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {HandLandmarker} from './hand_landmarker'; import {HandLandmarkerOptions} from './hand_landmarker_options'; + // The OSS JS API does not support the builder pattern. 
// tslint:disable:jspb-use-builder-pattern @@ -87,7 +88,7 @@ class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake { spyOn(this.graphRunner, 'addProtoToStream'); } - getGraphRunner(): GraphRunnerImageLib { + getGraphRunner(): VisionGraphRunner { return this.graphRunner; } } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 9298a860c..4a2be5566 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -22,7 +22,7 @@ import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_cla import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -97,7 +97,7 @@ export class ImageClassifier extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index cf0bd8c5d..4651ae4ce 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -24,7 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -99,7 +99,7 @@ export class ImageEmbedder extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e4c51de08..ac489ec00 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -20,7 +20,7 @@ import {Detection as DetectionProto} from '../../../../framework/formats/detecti import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from 
'../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -96,7 +96,7 @@ export class ObjectDetector extends VisionTaskRunner { constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + super(new VisionGraphRunner(wasmModule, glCanvas)); this.options.setBaseOptions(new BaseOptionsProto()); } From 9055effddd35e0424db2a11a81445f32f6badae8 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 20:52:30 -0800 Subject: [PATCH 268/346] Add ImageProcessingOptions to all Vision Tasks PiperOrigin-RevId: 500323261 --- .../cc/vision/core/image_processing_options.h | 2 +- .../tasks/web/components/containers/BUILD | 5 + .../tasks/web/components/containers/rect.d.ts | 41 +++++ .../tasks/web/core/task_runner_test_utils.ts | 6 +- mediapipe/tasks/web/vision/core/BUILD | 12 ++ .../vision/core/image_processing_options.d.ts | 42 +++++ .../vision/core/vision_task_runner.test.ts | 158 ++++++++++++++++-- .../web/vision/core/vision_task_runner.ts | 95 +++++++++-- .../tasks/web/vision/gesture_recognizer/BUILD | 2 +- .../gesture_recognizer/gesture_recognizer.ts | 48 +++--- .../tasks/web/vision/hand_landmarker/BUILD | 2 +- .../vision/hand_landmarker/hand_landmarker.ts | 46 ++--- .../tasks/web/vision/image_classifier/BUILD | 1 + .../image_classifier/image_classifier.ts | 43 ++--- .../tasks/web/vision/image_embedder/BUILD | 1 + .../vision/image_embedder/image_embedder.ts | 45 ++--- .../tasks/web/vision/object_detector/BUILD | 1 + .../vision/object_detector/object_detector.ts | 42 +++-- 18 files changed, 460 insertions(+), 132 deletions(-) create mode 100644 mediapipe/tasks/web/components/containers/rect.d.ts create mode 100644 mediapipe/tasks/web/vision/core/image_processing_options.d.ts diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h index 1983272fc..e2647be71 100644 --- a/mediapipe/tasks/cc/vision/core/image_processing_options.h +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -28,7 +28,7 @@ namespace core { // Options for image processing. // // If both region-or-interest and rotation are specified, the crop around the -// region-of-interest is extracted first, the the specified rotation is applied +// region-of-interest is extracted first, then the specified rotation is applied // to the crop. struct ImageProcessingOptions { // The optional region-of-interest to crop from the image. 
If not specified, diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index fb0fdff16..a0db59d0b 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -24,3 +24,8 @@ mediapipe_ts_declaration( name = "embedding_result", srcs = ["embedding_result.d.ts"], ) + +mediapipe_ts_declaration( + name = "rect", + srcs = ["rect.d.ts"], +) diff --git a/mediapipe/tasks/web/components/containers/rect.d.ts b/mediapipe/tasks/web/components/containers/rect.d.ts new file mode 100644 index 000000000..9afece9ca --- /dev/null +++ b/mediapipe/tasks/web/components/containers/rect.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Defines a rectangle, used e.g. as part of detection results or as input + * region-of-interest. + */ +export declare interface Rect { + left: number; + top: number; + right: number; + bottom: number; +} + +/** + * Defines a rectangle, used e.g. as part of detection results or as input + * region-of-interest. + * + * The coordinates are normalized with respect to the image dimensions, i.e. + * generally in [0,1] but they may exceed these bounds if describing a region + * overlapping the image. The origin is on the top-left corner of the image. + */ +export declare interface RectF { + left: number; + top: number; + right: number; + bottom: number; +} diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 838b3f585..62dd0463a 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -32,12 +32,14 @@ export declare type SpyWasmModule = jasmine.SpyObj; * in pure JS/TS (and optionally spy on the calls). 
 */
export function createSpyWasmModule(): SpyWasmModule {
-  return jasmine.createSpyObj([
+  const spyWasmModule = jasmine.createSpyObj([
     '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
     '_attachProtoVectorListener', '_free', '_waitUntilIdle',
     '_addStringToInputStream', '_registerModelResourcesGraphService',
-    '_configureAudio'
+    '_configureAudio', '_malloc', '_addProtoToInputStream'
   ]);
+  spyWasmModule.HEAPU8 = jasmine.createSpyObj(['set']);
+  return spyWasmModule;
 }
 
 /**
diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD
index 3574483df..a0a008122 100644
--- a/mediapipe/tasks/web/vision/core/BUILD
+++ b/mediapipe/tasks/web/vision/core/BUILD
@@ -5,6 +5,14 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration",
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+mediapipe_ts_declaration(
+    name = "image_processing_options",
+    srcs = ["image_processing_options.d.ts"],
+    deps = [
+        "//mediapipe/tasks/web/components/containers:rect",
+    ],
+)
+
 mediapipe_ts_declaration(
     name = "vision_task_options",
     srcs = ["vision_task_options.d.ts"],
@@ -17,7 +25,9 @@ mediapipe_ts_library(
     name = "vision_task_runner",
     srcs = ["vision_task_runner.ts"],
     deps = [
+        ":image_processing_options",
         ":vision_task_options",
+        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:task_runner",
         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
@@ -31,8 +41,10 @@ mediapipe_ts_library(
     testonly = True,
     srcs = ["vision_task_runner.test.ts"],
    deps = [
+        ":image_processing_options",
         ":vision_task_options",
         ":vision_task_runner",
+        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/web/core:task_runner_test_utils",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/core/image_processing_options.d.ts b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts
new file mode 100644
index 000000000..b76731546
--- /dev/null
+++ b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {RectF} from '../../../../tasks/web/components/containers/rect';
+
+/**
+ * Options for image processing.
+ *
+ * If both region-of-interest and rotation are specified, the crop around the
+ * region-of-interest is extracted first, then the specified rotation is
+ * applied to the crop.
+ */
+export declare interface ImageProcessingOptions {
+  /**
+   * The optional region-of-interest to crop from the image. If not specified,
+   * the full image is used.
+   *
+   * Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+   */
+  regionOfInterest?: RectF;
+
+  /**
+   * The rotation to apply to the image (or cropped region-of-interest), in
+   * degrees clockwise.
+   *
+   * The rotation must be a multiple (positive or negative) of 90°.
+   */
+  rotationDegrees?: number;
+}
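The interface above is consumed by the per-task `classify()`, `detect()`, `recognize()` and `embed()` methods that the rest of this patch updates. As a usage sketch only (the `imageClassifier` and `image` values are assumed to already exist; they are not part of this patch):

    // Crop to the right half of the input and rotate the crop 90°
    // clockwise before running inference.
    const options: ImageProcessingOptions = {
      regionOfInterest: {left: 0.5, top: 0, right: 1, bottom: 1},
      rotationDegrees: 90,
    };
    const result = imageClassifier.classify(image, options);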
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
index f3f25070e..a48381038 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
@@ -16,21 +16,62 @@
 
 import 'jasmine';
 
+import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
-import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
+import {addJasmineCustomFloatEqualityTester} from '../../../../tasks/web/core/task_runner_test_utils';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 
 import {VisionTaskOptions} from './vision_task_options';
 import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner';
 
-class VisionTaskRunnerFake extends VisionTaskRunner<void> {
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
+
+const IMAGE = {} as unknown as HTMLImageElement;
+const TIMESTAMP = 42;
+
+class VisionTaskRunnerFake extends VisionTaskRunner {
   baseOptions = new BaseOptionsProto();
+  fakeGraphRunner: jasmine.SpyObj<VisionGraphRunner>;
+  expectedImageSource?: ImageSource;
+  expectedNormalizedRect?: NormalizedRect;
 
   constructor() {
-    super(new VisionGraphRunner(createSpyWasmModule(), /* glCanvas= */ null));
-  }
+    super(
+        jasmine.createSpyObj([
+          'addProtoToStream', 'addGpuBufferAsImageToStream',
+          'setAutoRenderToScreen', 'registerModelResourcesGraphService',
+          'finishProcessing'
+        ]),
+        IMAGE_STREAM, NORM_RECT_STREAM);
 
-  protected override process(): void {}
+    this.fakeGraphRunner =
+        this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>;
+
+    (this.graphRunner.addProtoToStream as jasmine.Spy)
+        .and.callFake((serializedData, type, streamName, timestamp) => {
+          expect(type).toBe('mediapipe.NormalizedRect');
+          expect(streamName).toBe(NORM_RECT_STREAM);
+          expect(timestamp).toBe(TIMESTAMP);
+
+          const actualNormalizedRect =
+              NormalizedRect.deserializeBinary(serializedData);
+          expect(actualNormalizedRect.toObject())
+              .toEqual(this.expectedNormalizedRect!.toObject());
+        });
+
+    (this.graphRunner.addGpuBufferAsImageToStream as jasmine.Spy)
+        .and.callFake((imageSource, streamName, timestamp) => {
+          expect(streamName).toBe(IMAGE_STREAM);
+          expect(timestamp).toBe(TIMESTAMP);
+          expect(imageSource).toBe(this.expectedImageSource!);
+        });
+  }
 
   protected override refreshGraph(): void {}
 
@@ -38,12 +79,31 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
     return this.applyOptions(options);
   }
 
-  override processImageData(image: ImageSource): void {
-    super.processImageData(image);
+  override processImageData(
+      image: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined): void {
+    super.processImageData(image, imageProcessingOptions);
   }
 
-  override processVideoData(imageFrame: ImageSource, timestamp: number): void {
-    super.processVideoData(imageFrame, timestamp);
+  override processVideoData(
+      imageFrame: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
+    super.processVideoData(imageFrame, imageProcessingOptions, timestamp);
+  }
+
+  expectNormalizedRect(
xCenter: number, yCenter: number, width: number, height: number): void { + const rect = new NormalizedRect(); + rect.setXCenter(xCenter); + rect.setYCenter(yCenter); + rect.setWidth(width); + rect.setHeight(height); + this.expectedNormalizedRect = rect; + } + + expectImage(imageSource: ImageSource): void { + this.expectedImageSource = imageSource; } } @@ -51,6 +111,7 @@ describe('VisionTaskRunner', () => { let visionTaskRunner: VisionTaskRunnerFake; beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions( {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); @@ -72,7 +133,8 @@ describe('VisionTaskRunner', () => { await visionTaskRunner.setOptions({runningMode: 'video'}); // Clear running mode - await visionTaskRunner.setOptions({runningMode: undefined}); + await visionTaskRunner.setOptions( + {runningMode: /* imageProcessingOptions= */ undefined}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: false})); }); @@ -80,20 +142,90 @@ describe('VisionTaskRunner', () => { it('cannot process images with video mode', async () => { await visionTaskRunner.setOptions({runningMode: 'video'}); expect(() => { - visionTaskRunner.processImageData({} as HTMLImageElement); + visionTaskRunner.processImageData( + IMAGE, /* imageProcessingOptions= */ undefined); }).toThrowError(/Task is not initialized with image mode./); }); it('cannot process video with image mode', async () => { // Use default for `useStreamMode` expect(() => { - visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + visionTaskRunner.processVideoData( + IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); }).toThrowError(/Task is not initialized with video mode./); // Explicitly set to image mode await visionTaskRunner.setOptions({runningMode: 'image'}); expect(() => { - visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + visionTaskRunner.processVideoData( + IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); }).toThrowError(/Task is not initialized with video mode./); }); + + it('sends packets to graph', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + + visionTaskRunner.expectImage(IMAGE); + visionTaskRunner.expectNormalizedRect(0.5, 0.5, 1, 1); + visionTaskRunner.processVideoData( + IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); + }); + + it('sends packets to graph with image processing options', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + + visionTaskRunner.expectImage(IMAGE); + visionTaskRunner.expectNormalizedRect(0.3, 0.6, 0.2, 0.4); + visionTaskRunner.processVideoData( + IMAGE, + {regionOfInterest: {left: 0.2, right: 0.4, top: 0.4, bottom: 0.8}}, + TIMESTAMP); + }); + + describe('validates processing options', () => { + it('with left > right', () => { + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.2, + right: 0.1, + top: 0.1, + bottom: 0.2, + } + }); + }).toThrowError('Expected RectF with left < right and top < bottom.'); + }); + + it('with top > bottom', () => { + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 0.2, + top: 0.2, + bottom: 0.1, + } + }); + }).toThrowError('Expected RectF with left < right and top < bottom.'); + }); + + it('with out of range values', () => { + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 1.1, 
+            top: 0.1,
+            bottom: 0.2,
+          }
+        });
+      }).toThrowError('Expected RectF values to be in [0,1].');
+    });
+
+    it('with non-90 degree rotation', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42});
+      }).toThrowError('Expected rotation to be a multiple of 90°.');
+    });
+  });
 });
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index c3e0d3c7e..9adc810fc 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -14,7 +14,9 @@
  * limitations under the License.
  */
 
+import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service';
@@ -27,10 +29,26 @@ const GraphRunnerVisionType =
 /** An implementation of the GraphRunner that supports image operations */
 export class VisionGraphRunner extends GraphRunnerVisionType {}
 
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
 /** Base class for all MediaPipe Vision Tasks. */
-export abstract class VisionTaskRunner<T> extends TaskRunner {
-  /** @hideconstructor protected */
-  constructor(protected override readonly graphRunner: VisionGraphRunner) {
+export abstract class VisionTaskRunner extends TaskRunner {
+  /**
+   * Constructor to initialize a `VisionTaskRunner`.
+   *
+   * @param graphRunner the graph runner for this task.
+   * @param imageStreamName the name of the input image stream.
+   * @param normRectStreamName the name of the input normalized rect image
+   *     stream used to provide (mandatory) rotation and (optional)
+   *     region-of-interest.
+   *
+   * @hideconstructor protected
+   */
+  constructor(
+      protected override readonly graphRunner: VisionGraphRunner,
+      private readonly imageStreamName: string,
+      private readonly normRectStreamName: string) {
     super(graphRunner);
   }
 
@@ -44,27 +62,84 @@ export abstract class VisionTaskRunner<T> extends TaskRunner {
     return super.applyOptions(options);
   }
 
-  /** Sends an image packet to the graph and awaits results. */
-  protected abstract process(input: ImageSource, timestamp: number): T;
-
   /** Sends a single image to the graph and awaits results. */
-  protected processImageData(image: ImageSource): T {
+  protected processImageData(
+      image: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined): void {
     if (!!this.baseOptions?.getUseStreamMode()) {
       throw new Error(
           'Task is not initialized with image mode. ' +
           '\'runningMode\' must be set to \'image\'.');
     }
-    return this.process(image, performance.now());
+    this.process(image, imageProcessingOptions, performance.now());
   }
 
   /** Sends a single video frame to the graph and awaits results. */
-  protected processVideoData(imageFrame: ImageSource, timestamp: number): T {
+  protected processVideoData(
+      imageFrame: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
     if (!this.baseOptions?.getUseStreamMode()) {
       throw new Error(
          'Task is not initialized with video mode.
' + '\'runningMode\' must be set to \'video\'.'); } - return this.process(imageFrame, timestamp); + this.process(imageFrame, imageProcessingOptions, timestamp); + } + + private convertToNormalizedRect(imageProcessingOptions?: + ImageProcessingOptions): NormalizedRect { + const normalizedRect = new NormalizedRect(); + + if (imageProcessingOptions?.regionOfInterest) { + const roi = imageProcessingOptions.regionOfInterest; + + if (roi.left >= roi.right || roi.top >= roi.bottom) { + throw new Error('Expected RectF with left < right and top < bottom.'); + } + if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { + throw new Error('Expected RectF values to be in [0,1].'); + } + + normalizedRect.setXCenter((roi.left + roi.right) / 2.0); + normalizedRect.setYCenter((roi.top + roi.bottom) / 2.0); + normalizedRect.setWidth(roi.right - roi.left); + normalizedRect.setHeight(roi.bottom - roi.top); + return normalizedRect; + } else { + normalizedRect.setXCenter(0.5); + normalizedRect.setYCenter(0.5); + normalizedRect.setWidth(1); + normalizedRect.setHeight(1); + } + + if (imageProcessingOptions?.rotationDegrees) { + if (imageProcessingOptions?.rotationDegrees % 90 !== 0) { + throw new Error( + 'Expected rotation to be a multiple of 90°.', + ); + } + + // Convert to radians anti-clockwise. + normalizedRect.setRotation( + -Math.PI * imageProcessingOptions.rotationDegrees / 180.0); + } + + return normalizedRect; + } + + /** Runs the graph and blocks on the response. */ + private process( + imageSource: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined, + timestamp: number): void { + const normalizedRect = this.convertToNormalizedRect(imageProcessingOptions); + this.graphRunner.addProtoToStream( + normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', + this.normRectStreamName, timestamp); + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, this.imageStreamName, timestamp ?? 
performance.now());
+    this.finishProcessing();
+  }
 }
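The `convertToNormalizedRect()` helper above is the heart of this change: it maps the caller-facing, corner-based `RectF` onto MediaPipe's center-based `mediapipe.NormalizedRect`, and converts clockwise degrees into anti-clockwise radians. A hand-worked check of the arithmetic, using the same numbers as the "sends packets to graph with image processing options" test earlier:

    // ROI {left: 0.2, right: 0.4, top: 0.4, bottom: 0.8} becomes:
    //   xCenter = (0.2 + 0.4) / 2 = 0.3
    //   yCenter = (0.4 + 0.8) / 2 = 0.6
    //   width   = 0.4 - 0.2      = 0.2
    //   height  = 0.8 - 0.4      = 0.4
    // rotationDegrees: 90 becomes -Math.PI * 90 / 180 = -π/2 radians.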
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
index 5fdf9b43e..9156e89b7 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
@@ -20,7 +20,6 @@ mediapipe_ts_library(
         "//mediapipe/framework:calculator_options_jspb_proto",
         "//mediapipe/framework/formats:classification_jspb_proto",
         "//mediapipe/framework/formats:landmark_jspb_proto",
-        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto",
@@ -33,6 +32,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:classifier_options",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index 8d36ed89c..e0c6affcb 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -18,7 +18,6 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {ClassificationList} from '../../../../framework/formats/classification_pb';
 import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
-import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb';
 import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb';
@@ -30,6 +29,7 @@ import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -57,15 +57,8 @@
 const DEFAULT_NUM_HANDS = 1;
 const DEFAULT_SCORE_THRESHOLD = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
 
-const FULL_IMAGE_RECT = new NormalizedRect();
-FULL_IMAGE_RECT.setXCenter(0.5);
-FULL_IMAGE_RECT.setYCenter(0.5);
-FULL_IMAGE_RECT.setWidth(1);
-FULL_IMAGE_RECT.setHeight(1);
-
 /** Performs hand gesture recognition on images. */
-export class GestureRecognizer extends
-    VisionTaskRunner<GestureRecognizerResult> {
+export class GestureRecognizer extends VisionTaskRunner {
   private gestures: Category[][] = [];
   private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
@@ -131,7 +124,9 @@
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
 
     this.options = new GestureRecognizerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
@@ -228,10 +223,16 @@
    * GestureRecognizer is created with running mode `image`.
    *
    * @param image A single image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected gestures.
    */
-  recognize(image: ImageSource): GestureRecognizerResult {
-    return this.processImageData(image);
+  recognize(
+      image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      GestureRecognizerResult {
+    this.resetResults();
+    this.processImageData(image, imageProcessingOptions);
+    return this.processResults();
   }
 
   /**
@@ -241,28 +242,27 @@
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected gestures.
    */
-  recognizeForVideo(videoFrame: ImageSource, timestamp: number):
+  recognizeForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions):
       GestureRecognizerResult {
-    return this.processVideoData(videoFrame, timestamp);
+    this.resetResults();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.processResults();
   }
 
-  /** Runs the gesture recognition and blocks on the response.
*/ - protected override process(imageSource: ImageSource, timestamp: number): - GestureRecognizerResult { + private resetResults(): void { this.gestures = []; this.landmarks = []; this.worldLandmarks = []; this.handednesses = []; + } - this.graphRunner.addGpuBufferAsImageToStream( - imageSource, IMAGE_STREAM, timestamp); - this.graphRunner.addProtoToStream( - FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', - NORM_RECT_STREAM, timestamp); - this.finishProcessing(); - + private processResults(): GestureRecognizerResult { if (this.gestures.length === 0) { // If no gestures are detected in the image, just return an empty list return { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index e7083a050..c5687ee2f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -20,7 +20,6 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", - "//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", @@ -28,6 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 5db6d48f5..e238bc96f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -18,7 +18,6 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {NormalizedRect} from '../../../../framework/formats/rect_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; @@ -26,6 +25,7 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -51,14 
+51,9 @@ const HAND_LANDMARKER_GRAPH =
 const DEFAULT_NUM_HANDS = 1;
 const DEFAULT_SCORE_THRESHOLD = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
-const FULL_IMAGE_RECT = new NormalizedRect();
-FULL_IMAGE_RECT.setXCenter(0.5);
-FULL_IMAGE_RECT.setYCenter(0.5);
-FULL_IMAGE_RECT.setWidth(1);
-FULL_IMAGE_RECT.setHeight(1);
 
 /** Performs hand landmarks detection on images. */
-export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
+export class HandLandmarker extends VisionTaskRunner {
   private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
   private handednesses: Category[][] = [];
@@ -119,7 +114,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
 
     this.options = new HandLandmarkerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
@@ -180,10 +177,15 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
    * HandLandmarker is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected hand landmarks.
    */
-  detect(image: ImageSource): HandLandmarkerResult {
-    return this.processImageData(image);
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      HandLandmarkerResult {
+    this.resetResults();
+    this.processImageData(image, imageProcessingOptions);
+    return this.processResults();
   }
 
   /**
@@ -193,27 +195,25 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected hand landmarks.
    */
-  detectForVideo(videoFrame: ImageSource, timestamp: number):
-      HandLandmarkerResult {
-    return this.processVideoData(videoFrame, timestamp);
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): HandLandmarkerResult {
+    this.resetResults();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.processResults();
   }
 
-  /** Runs the hand landmarker graph and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      HandLandmarkerResult {
+  private resetResults(): void {
     this.landmarks = [];
     this.worldLandmarks = [];
     this.handednesses = [];
+  }
 
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, IMAGE_STREAM, timestamp);
-    this.graphRunner.addProtoToStream(
-        FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
-        NORM_RECT_STREAM, timestamp);
-    this.finishProcessing();
-
+  private processResults(): HandLandmarkerResult {
     return {
       landmarks: this.landmarks,
       worldLandmarks: this.worldLandmarks,
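The hand landmarker now follows the same three-step flow as the gesture recognizer above: clear the cached output arrays, push the image and `NormalizedRect` packets through the base class, then assemble the return value from whatever the stream listeners collected. A usage sketch only (the `handLandmarker` and `image` values are assumed to exist; they are not part of this patch):

    // Detect landmarks only inside the centered half of the frame.
    const result = handLandmarker.detect(image, {
      regionOfInterest: {left: 0.25, top: 0.25, right: 0.75, bottom: 0.75},
    });
    console.log(result.landmarks, result.handednesses);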
diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD
index 310575964..86c7d8457 100644
--- a/mediapipe/tasks/web/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/web/vision/image_classifier/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:classifier_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
index 4a2be5566..2ad4a821d 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
@@ -22,6 +22,7 @@ import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_cla
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -31,7 +32,8 @@ import {ImageClassifierResult} from './image_classifier_result';
 const IMAGE_CLASSIFIER_GRAPH =
     'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
-const INPUT_STREAM = 'input_image';
+const IMAGE_STREAM = 'input_image';
+const NORM_RECT_STREAM = 'norm_rect';
 const CLASSIFICATIONS_STREAM = 'classifications';
 
 export * from './image_classifier_options';
@@ -42,7 +44,7 @@ export {ImageSource};  // Used in the public API
 // tslint:disable:jspb-use-builder-pattern
 
 /** Performs classification on images. */
-export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
+export class ImageClassifier extends VisionTaskRunner {
   private classificationResult: ImageClassifierResult = {classifications: []};
   private readonly options = new ImageClassifierGraphOptions();
 
@@ -97,7 +99,9 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
@@ -130,10 +134,15 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
    * ImageClassifier is created with running mode `image`.
* * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - classify(image: ImageSource): ImageClassifierResult { - return this.processImageData(image); + classify(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + ImageClassifierResult { + this.classificationResult = {classifications: []}; + this.processImageData(image, imageProcessingOptions); + return this.classificationResult; } /** @@ -143,28 +152,23 @@ export class ImageClassifier extends VisionTaskRunner { * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - classifyForVideo(videoFrame: ImageSource, timestamp: number): - ImageClassifierResult { - return this.processVideoData(videoFrame, timestamp); - } - - /** Runs the image classification graph and blocks on the response. */ - protected override process(imageSource: ImageSource, timestamp: number): - ImageClassifierResult { - // Get classification result by running our MediaPipe graph. + classifyForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): ImageClassifierResult { this.classificationResult = {classifications: []}; - this.graphRunner.addGpuBufferAsImageToStream( - imageSource, INPUT_STREAM, timestamp ?? performance.now()); - this.finishProcessing(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); return this.classificationResult; } /** Updates the MediaPipe graph configuration. */ protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -175,7 +179,8 @@ export class ImageClassifier extends VisionTaskRunner { // are built-in. 
const classifierNode = new CalculatorGraphConfig.Node();
     classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
-    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    classifierNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    classifierNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
     classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
     classifierNode.setOptions(calculatorOptions);
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index de4785e6c..449cee9bb 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/utils:cosine_similarity",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index 4651ae4ce..64a10f5f4 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -24,6 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
 import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -31,10 +32,12 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 import {ImageEmbedderOptions} from './image_embedder_options';
 import {ImageEmbedderResult} from './image_embedder_result';
 
+
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern
 
-const INPUT_STREAM = 'image_in';
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
 const EMBEDDINGS_STREAM = 'embeddings_out';
 const TEXT_EMBEDDER_CALCULATOR =
     'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph';
@@ -44,7 +47,7 @@ export * from './image_embedder_result';
 export {ImageSource};  // Used in the public API
 
 /** Performs embedding extraction on images. */
-export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
+export class ImageEmbedder extends VisionTaskRunner {
   private readonly options = new ImageEmbedderGraphOptions();
   private embeddings: ImageEmbedderResult = {embeddings: []};
 
@@ -99,7 +102,9 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
   constructor(
      wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
@@ -132,10 +137,14 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
    * ImageEmbedder is created with running mode `image`.
* * @param image The image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - embed(image: ImageSource): ImageEmbedderResult { - return this.processImageData(image); + embed(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + ImageEmbedderResult { + this.processImageData(image, imageProcessingOptions); + return this.embeddings; } /** @@ -145,11 +154,15 @@ export class ImageEmbedder extends VisionTaskRunner { * * @param imageFrame The image frame to process. * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - embedForVideo(imageFrame: ImageSource, timestamp: number): - ImageEmbedderResult { - return this.processVideoData(imageFrame, timestamp); + embedForVideo( + imageFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): ImageEmbedderResult { + this.processVideoData(imageFrame, imageProcessingOptions, timestamp); + return this.embeddings; } /** @@ -165,16 +178,6 @@ export class ImageEmbedder extends VisionTaskRunner { return computeCosineSimilarity(u, v); } - /** Runs the embedding extraction and blocks on the response. */ - protected process(image: ImageSource, timestamp: number): - ImageEmbedderResult { - // Get embeddings by running our MediaPipe graph. - this.graphRunner.addGpuBufferAsImageToStream( - image, INPUT_STREAM, timestamp ?? performance.now()); - this.finishProcessing(); - return this.embeddings; - } - /** * Internal function for converting raw data into an embedding, and setting it * as our embeddings result. @@ -187,7 +190,8 @@ export class ImageEmbedder extends VisionTaskRunner { /** Updates the MediaPipe graph configuration. 
*/ protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -195,7 +199,8 @@ export class ImageEmbedder extends VisionTaskRunner { const embedderNode = new CalculatorGraphConfig.Node(); embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR); - embedderNode.addInputStream('IMAGE:' + INPUT_STREAM); + embedderNode.addInputStream('IMAGE:' + IMAGE_STREAM); + embedderNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); embedderNode.setOptions(calculatorOptions); diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index fc206a2d7..76fa589c8 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -23,6 +23,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index ac489ec00..3a79c1b00 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -20,6 +20,7 @@ import {Detection as DetectionProto} from '../../../../framework/formats/detecti import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -27,7 +28,8 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner import {ObjectDetectorOptions} from './object_detector_options'; import {Detection} from './object_detector_result'; -const INPUT_STREAM = 'input_frame_gpu'; +const IMAGE_STREAM = 'input_frame_gpu'; +const NORM_RECT_STREAM = 'norm_rect'; const DETECTIONS_STREAM = 'detections'; const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; @@ -41,7 +43,7 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs object detection on images. 
 */
-export class ObjectDetector extends VisionTaskRunner<Detection[]> {
+export class ObjectDetector extends VisionTaskRunner {
   private detections: Detection[] = [];
   private readonly options = new ObjectDetectorOptionsProto();
 
@@ -96,7 +98,9 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
@@ -160,10 +164,15 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
    * ObjectDetector is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The list of detected objects
    */
-  detect(image: ImageSource): Detection[] {
-    return this.processImageData(image);
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      Detection[] {
+    this.detections = [];
+    this.processImageData(image, imageProcessingOptions);
+    return [...this.detections];
   }
 
   /**
@@ -173,20 +182,15 @@
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The list of detected objects
    */
-  detectForVideo(videoFrame: ImageSource, timestamp: number): Detection[] {
-    return this.processVideoData(videoFrame, timestamp);
-  }
-
-  /** Runs the object detector graph and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      Detection[] {
-    // Get detections by running our MediaPipe graph.
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): Detection[] {
     this.detections = [];
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, INPUT_STREAM, timestamp ?? performance.now());
-    this.finishProcessing();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
     return [...this.detections];
   }
 
@@ -226,7 +230,8 @@
   /** Updates the MediaPipe graph configuration.
*/ protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -235,7 +240,8 @@ export class ObjectDetector extends VisionTaskRunner { const detectorNode = new CalculatorGraphConfig.Node(); detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH); - detectorNode.addInputStream('IMAGE:' + INPUT_STREAM); + detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM); + detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM); detectorNode.setOptions(calculatorOptions); From b4ede6db7bf85071893c7edd29cde9e5d7a288f9 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 21:00:22 -0800 Subject: [PATCH 269/346] Fix typo in Category.java PiperOrigin-RevId: 500324008 --- .../mediapipe/tasks/components/containers/Category.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java index e955605e4..ab3fd0bd8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java @@ -19,9 +19,9 @@ import com.google.mediapipe.formats.proto.ClassificationProto; import java.util.Objects; /** - * Category is a util class, contains a category name, its display name, a float value as score, and - * the index of the label in the corresponding label file. Typically it's used as result of - * classification or detection tasks. + * Category is a util class, that contains a category name, its display name, a float value as + * score, and the index of the label in the corresponding label file. Typically it's used as result + * of classification or detection tasks. */ @AutoValue public abstract class Category { From ed0054836a62d52771520aa4f07be6b1c5ad3962 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 21:04:57 -0800 Subject: [PATCH 270/346] Allow task to recover after a failed graph start PiperOrigin-RevId: 500324587 --- mediapipe/tasks/web/core/task_runner.ts | 21 +++++++++++--------- mediapipe/tasks/web/core/task_runner_test.ts | 12 +++++++++++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index a3df7adf5..c2679b773 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -164,16 +164,19 @@ export abstract class TaskRunner { /** Throws the error from the error listener if an error was raised. 
*/ private handleErrors() { - const errorCount = this.processingErrors.length; - if (errorCount === 1) { - // Re-throw error to get a more meaningful stacktrace - throw new Error(this.processingErrors[0].message); - } else if (errorCount > 1) { - throw new Error( - 'Encountered multiple errors: ' + - this.processingErrors.map(e => e.message).join(', ')); + try { + const errorCount = this.processingErrors.length; + if (errorCount === 1) { + // Re-throw error to get a more meaningful stacktrace + throw new Error(this.processingErrors[0].message); + } else if (errorCount > 1) { + throw new Error( + 'Encountered multiple errors: ' + + this.processingErrors.map(e => e.message).join(', ')); + } + } finally { + this.processingErrors = []; } - this.processingErrors = []; } /** Configures the `externalFile` option */ diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index 684beb70c..9a8aa32eb 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -139,6 +139,18 @@ describe('TaskRunner', () => { }).toThrowError(/Test error 1, Test error 2/); }); + it('clears errors once thrown', () => { + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error/); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).not.toThrow(); + }); + it('verifies that at least one model asset option is provided', () => { expect(() => { taskRunner.setOptions({}); From c9ebc6fa606888542ad89b978c2658c127d4226f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 6 Jan 2023 21:34:46 -0800 Subject: [PATCH 271/346] Use synthetic timestamps in Web when none provided PiperOrigin-RevId: 500327275 --- .../audio_classifier/audio_classifier.ts | 5 ++++- .../audio/audio_embedder/audio_embedder.ts | 18 ++++++++++++------ .../tasks/web/audio/core/audio_task_runner.ts | 5 ++++- mediapipe/tasks/web/core/task_runner.ts | 18 +++++++++++++++++- .../text/text_classifier/text_classifier.ts | 10 ++++++---- .../web/text/text_embedder/text_embedder.ts | 19 ++++++++++++------- .../web/vision/core/vision_task_runner.ts | 6 +++++- .../gesture_recognizer/gesture_recognizer.ts | 12 ++++++++---- .../vision/hand_landmarker/hand_landmarker.ts | 9 ++++++--- .../image_classifier/image_classifier.ts | 3 ++- .../vision/image_embedder/image_embedder.ts | 8 +++++--- .../vision/object_detector/object_detector.ts | 5 +++-- 12 files changed, 84 insertions(+), 34 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 92fca93ad..e26ead6a9 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -126,6 +126,8 @@ export class AudioClassifier extends AudioTaskRunner { return this.applyOptions(options); } + // TODO: Add a classifyStream() that takes a timestamp + /** * Performs audio classification on the provided audio clip and waits * synchronously for the response. 
@@ -194,8 +196,9 @@ export class AudioClassifier extends AudioTaskRunner {
     graphConfig.addNode(classifierNode);
 
     this.graphRunner.attachProtoVectorListener(
-        TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
+        TIMESTAMPED_CLASSIFICATIONS_STREAM, (binaryProtos, timestamp) => {
           this.addJsAudioClassificationResults(binaryProtos);
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
index 2e210f969..7411f95ef 100644
--- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
+++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
@@ -128,6 +128,8 @@ export class AudioEmbedder extends AudioTaskRunner {
     return this.applyOptions(options);
   }
 
+  // TODO: Add an embedStream() that takes a timestamp
+
  /**
   * Performs embedding extraction on the provided audio clip and waits
   * synchronously for the response.
@@ -193,20 +195,24 @@ export class AudioEmbedder extends AudioTaskRunner {
     graphConfig.addNode(embedderNode);
 
-    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
-      const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
-      this.embeddingResults.push(
-          convertFromEmbeddingResultProto(embeddingResult));
-    });
+    this.graphRunner.attachProtoListener(
+        EMBEDDINGS_STREAM, (binaryProto, timestamp) => {
+          const embeddingResult =
+              EmbeddingResult.deserializeBinary(binaryProto);
+          this.embeddingResults.push(
+              convertFromEmbeddingResultProto(embeddingResult));
+          this.setLatestOutputTimestamp(timestamp);
+        });
 
     this.graphRunner.attachProtoVectorListener(
-        TIMESTAMPED_EMBEDDINGS_STREAM, data => {
+        TIMESTAMPED_EMBEDDINGS_STREAM, (data, timestamp) => {
           for (const binaryProto of data) {
             const embeddingResult =
                 EmbeddingResult.deserializeBinary(binaryProto);
             this.embeddingResults.push(
                 convertFromEmbeddingResultProto(embeddingResult));
           }
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts
index 24d78378d..ff39185f2 100644
--- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts
+++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts
@@ -36,8 +36,11 @@ export abstract class AudioTaskRunner extends TaskRunner {
 
   /** Sends a single audio clip to the graph and awaits results. */
   protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
+    // Increment the timestamp by 1 millisecond to guarantee that we send
+    // monotonically increasing timestamps to the graph.
+    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     return this.process(
-        audioData, sampleRate ?? this.defaultSampleRate, performance.now());
+        audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
   }
 }
 
diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts
index c2679b773..8d483d9ff 100644
--- a/mediapipe/tasks/web/core/task_runner.ts
+++ b/mediapipe/tasks/web/core/task_runner.ts
@@ -50,7 +50,7 @@ export async function createTaskRunner(
     }
   };
 
-  // Initialize a canvas if requested. If OffscreenCanvas is availble, we
+  // Initialize a canvas if requested. If OffscreenCanvas is available, we
   // let the graph runner initialize it by passing `undefined`.
   const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
                                         document.createElement('canvas') :
@@ -66,6 +66,7 @@ export async function createTaskRunner(
 
 export abstract class TaskRunner {
   protected abstract baseOptions: BaseOptionsProto;
   private processingErrors: Error[] = [];
+  private latestOutputTimestamp = 0;
 
   /**
    * Creates a new instance of a Mediapipe Task. Determines if SIMD is
@@ -162,6 +163,21 @@ export abstract class TaskRunner {
     this.handleErrors();
   }
 
+  /*
+   * Sets the latest output timestamp received from the graph (in ms).
+   * Timestamps that are smaller than the currently latest output timestamp are
+   * ignored.
+   */
+  protected setLatestOutputTimestamp(timestamp: number): void {
+    this.latestOutputTimestamp =
+        Math.max(this.latestOutputTimestamp, timestamp);
+  }
+
+  /** Returns the latest output timestamp. */
+  protected getLatestOutputTimestamp() {
+    return this.latestOutputTimestamp;
+  }
+
   /** Throws the error from the error listener if an error was raised. */
   private handleErrors() {
     try {
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
index 6aef1b3e4..ff314cfc3 100644
--- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
+++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
@@ -131,10 +131,11 @@ export class TextClassifier extends TaskRunner {
    * @return The classification result of the text
    */
   classify(text: string): TextClassifierResult {
-    // Get classification result by running our MediaPipe graph.
+    // Increment the timestamp by 1 millisecond to guarantee that we send
+    // monotonically increasing timestamps to the graph.
+    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     this.classificationResult = {classifications: []};
-    this.graphRunner.addStringToStream(
-        text, INPUT_STREAM, /* timestamp= */ performance.now());
+    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
     this.finishProcessing();
     return this.classificationResult;
   }
@@ -158,9 +159,10 @@ export class TextClassifier extends TaskRunner {
     graphConfig.addNode(classifierNode);
 
     this.graphRunner.attachProtoListener(
-        CLASSIFICATIONS_STREAM, binaryProto => {
+        CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => {
           this.classificationResult = convertFromClassificationResultProto(
               ClassificationResult.deserializeBinary(binaryProto));
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
index db7986dec..daa1d24ed 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
@@ -135,9 +135,10 @@ export class TextEmbedder extends TaskRunner {
    * @return The embedding results of the text
    */
   embed(text: string): TextEmbedderResult {
-    // Get text embeddings by running our MediaPipe graph.
-    this.graphRunner.addStringToStream(
-        text, INPUT_STREAM, /* timestamp= */ performance.now());
+    // Increment the timestamp by 1 millisecond to guarantee that we send
+    // monotonically increasing timestamps to the graph.
+    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
+    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
     this.finishProcessing();
     return this.embeddingResult;
   }
@@ -173,10 +174,14 @@ export class TextEmbedder extends TaskRunner {
     graphConfig.addNode(embedderNode);
 
-    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
-      const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
-      this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult);
-    });
+    this.graphRunner.attachProtoListener(
+        EMBEDDINGS_STREAM, (binaryProto, timestamp) => {
+          const embeddingResult =
+              EmbeddingResult.deserializeBinary(binaryProto);
+          this.embeddingResult =
+              convertFromEmbeddingResultProto(embeddingResult);
+          this.setLatestOutputTimestamp(timestamp);
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index 9adc810fc..9ed9ffdb2 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -71,7 +71,11 @@ export abstract class VisionTaskRunner extends TaskRunner {
           'Task is not initialized with image mode. ' +
           '\'runningMode\' must be set to \'image\'.');
     }
-    this.process(image, imageProcessingOptions, performance.now());
+
+    // Increment the timestamp by 1 millisecond to guarantee that we send
+    // monotonically increasing timestamps to the graph.
+    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
+    this.process(image, imageProcessingOptions, syntheticTimestamp);
   }
 
   /** Sends a single video frame to the graph and awaits results. */
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index e0c6affcb..48efc4855 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -380,23 +380,27 @@ export class GestureRecognizer extends VisionTaskRunner {
     graphConfig.addNode(recognizerNode);
 
     this.graphRunner.attachProtoVectorListener(
-        LANDMARKS_STREAM, binaryProto => {
+        LANDMARKS_STREAM, (binaryProto, timestamp) => {
           this.addJsLandmarks(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
         });
     this.graphRunner.attachProtoVectorListener(
-        WORLD_LANDMARKS_STREAM, binaryProto => {
+        WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
          this.adddJsWorldLandmarks(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
         });
     this.graphRunner.attachProtoVectorListener(
-        HAND_GESTURES_STREAM, binaryProto => {
+        HAND_GESTURES_STREAM, (binaryProto, timestamp) => {
           // Gesture index is not used, because the final gesture result comes
           // from multiple classifiers.
           this.gestures.push(
               ...this.toJsCategories(binaryProto, /* populateIndex= */ false));
+          this.setLatestOutputTimestamp(timestamp);
         });
     this.graphRunner.attachProtoVectorListener(
-        HANDEDNESS_STREAM, binaryProto => {
+        HANDEDNESS_STREAM, (binaryProto, timestamp) => {
           this.handednesses.push(...this.toJsCategories(binaryProto));
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
index e238bc96f..b51fb6a52 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
@@ -313,16 +313,19 @@ export class HandLandmarker extends VisionTaskRunner {
     graphConfig.addNode(landmarkerNode);
 
     this.graphRunner.attachProtoVectorListener(
-        LANDMARKS_STREAM, binaryProto => {
+        LANDMARKS_STREAM, (binaryProto, timestamp) => {
           this.addJsLandmarks(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
         });
     this.graphRunner.attachProtoVectorListener(
-        WORLD_LANDMARKS_STREAM, binaryProto => {
+        WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
          this.adddJsWorldLandmarks(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
         });
     this.graphRunner.attachProtoVectorListener(
-        HANDEDNESS_STREAM, binaryProto => {
+        HANDEDNESS_STREAM, (binaryProto, timestamp) => {
           this.handednesses.push(...this.toJsCategories(binaryProto));
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
index 2ad4a821d..cb2849cd8 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
@@ -187,9 +187,10 @@ export class ImageClassifier extends VisionTaskRunner {
     graphConfig.addNode(classifierNode);
 
     this.graphRunner.attachProtoListener(
-        CLASSIFICATIONS_STREAM, binaryProto => {
+        CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => {
           this.classificationResult = convertFromClassificationResultProto(
               ClassificationResult.deserializeBinary(binaryProto));
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index 64a10f5f4..788646e6d 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -206,9 +206,11 @@ export class ImageEmbedder extends VisionTaskRunner {
     graphConfig.addNode(embedderNode);
 
-    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
-      this.addJsImageEmdedding(binaryProto);
-    });
+    this.graphRunner.attachProtoListener(
+        EMBEDDINGS_STREAM, (binaryProto, timestamp) => {
+          this.addJsImageEmdedding(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
index 3a79c1b00..5741a3a0c 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
@@ -176,7 +176,7 @@ export class ObjectDetector extends VisionTaskRunner {
   }
 
   /**
-   * Performs object detection on the provided vidoe frame and waits
+   * Performs object detection on the provided video frame and waits
    * synchronously for the response. Only use this method when the
    * ObjectDetector is created with running mode `video`.
    *
@@ -248,8 +248,9 @@ export class ObjectDetector extends VisionTaskRunner {
     graphConfig.addNode(detectorNode);
 
     this.graphRunner.attachProtoVectorListener(
-        DETECTIONS_STREAM, binaryProto => {
+        DETECTIONS_STREAM, (binaryProto, timestamp) => {
           this.addJsObjectDetections(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
         });
 
     const binaryGraph = graphConfig.serializeBinary();

From 7f043b7de1f4230359c4b16e5deae58cb9ea50b2 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 6 Jan 2023 21:40:09 -0800
Subject: [PATCH 272/346] Allow split_vector_calculator to be built with iOS
 and MEDIAPIPE_DISABLE_GPU

PiperOrigin-RevId: 500327774
---
 mediapipe/calculators/core/BUILD | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD
index b3378a74e..df54c5800 100644
--- a/mediapipe/calculators/core/BUILD
+++ b/mediapipe/calculators/core/BUILD
@@ -13,12 +13,21 @@
 # limitations under the License.
 #
 
+load("@bazel_skylib//lib:selects.bzl", "selects")
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
 
 licenses(["notice"])
 
 package(default_visibility = ["//visibility:public"])
 
+selects.config_setting_group(
+    name = "ios_or_disable_gpu",
+    match_any = [
+        "//mediapipe/gpu:disable_gpu",
+        "//mediapipe:ios",
+    ],
+)
+
 mediapipe_proto_library(
     name = "concatenate_vector_calculator_proto",
     srcs = ["concatenate_vector_calculator.proto"],
@@ -899,8 +908,7 @@ cc_library(
         "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
     ] + select({
-        "//mediapipe/gpu:disable_gpu": [],
-        "//mediapipe:ios": [],
+        ":ios_or_disable_gpu": [],
         "//conditions:default": [
            "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
        ],

From e0a254789a1ec05f3c09411b45a6c59d0ed3075e Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Fri, 6 Jan 2023 22:13:13 -0800
Subject: [PATCH 273/346] Internal change.

PiperOrigin-RevId: 500331015
---
 mediapipe/framework/formats/tensor/BUILD | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/framework/formats/tensor/BUILD b/mediapipe/framework/formats/tensor/BUILD
index c634b0dda..3895fc82e 100644
--- a/mediapipe/framework/formats/tensor/BUILD
+++ b/mediapipe/framework/formats/tensor/BUILD
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 package(
-    default_visibility = ["//visibility:public"],
+    default_visibility = ["//visibility:private"],
     features = ["-layering_check"],
 )

From 1bbe065647b30f7b457df56747b24510c225258d Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 9 Jan 2023 09:11:37 -0800
Subject: [PATCH 274/346] Simplify default options for GestureRecognizer

PiperOrigin-RevId: 500729643
---
 mediapipe/tasks/testdata/vision/BUILD         |  2 +
 .../gesture_recognizer/gesture_recognizer.ts  | 39 +++++--------------
 third_party/external_files.bzl                |  6 +++
 3 files changed, 17 insertions(+), 30 deletions(-)

diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 95b721fdb..607245700 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -38,6 +38,7 @@ mediapipe_files(srcs = [
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
     "deeplabv3.tflite",
     "fist.jpg",
+    "fist.png",
     "hand_landmark_full.tflite",
     "hand_landmark_lite.tflite",
     "hand_landmarker.task",
@@ -95,6 +96,7 @@ filegroup(
     "cats_and_dogs_no_resizing.jpg",
     "cats_and_dogs_rotated.jpg",
     "fist.jpg",
+    "fist.png",
     "hand_landmark_full.tflite",
     "hand_landmark_lite.tflite",
     "left_hands.jpg",
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index 48efc4855..1b7201b9a 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -54,7 +54,7 @@ const GESTURE_RECOGNIZER_GRAPH =
     'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
 
 const DEFAULT_NUM_HANDS = 1;
-const DEFAULT_SCORE_THRESHOLD = 0.5;
+const DEFAULT_CONFIDENCE = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
 
 /** Performs hand gesture recognition on images. */
@@ -143,8 +143,6 @@ export class GestureRecognizer extends VisionTaskRunner {
         new HandGestureRecognizerGraphOptions();
     this.options.setHandGestureRecognizerGraphOptions(
         this.handGestureRecognizerGraphOptions);
-
-    this.initDefaults();
   }
 
   protected override get baseOptions(): BaseOptionsProto {
@@ -165,22 +163,14 @@ export class GestureRecognizer extends VisionTaskRunner {
    * @param options The options for the gesture recognizer.
    */
   override setOptions(options: GestureRecognizerOptions): Promise<void> {
-    if ('numHands' in options) {
-      this.handDetectorGraphOptions.setNumHands(
-          options.numHands ?? DEFAULT_NUM_HANDS);
-    }
-    if ('minHandDetectionConfidence' in options) {
-      this.handDetectorGraphOptions.setMinDetectionConfidence(
-          options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
-    }
-    if ('minHandPresenceConfidence' in options) {
-      this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
-          options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
-    }
-    if ('minTrackingConfidence' in options) {
-      this.handLandmarkerGraphOptions.setMinTrackingConfidence(
-          options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD);
-    }
+    this.handDetectorGraphOptions.setNumHands(
+        options.numHands ?? DEFAULT_NUM_HANDS);
+    this.handDetectorGraphOptions.setMinDetectionConfidence(
+        options.minHandDetectionConfidence ?? DEFAULT_CONFIDENCE);
+    this.handLandmarkerGraphOptions.setMinTrackingConfidence(
+        options.minTrackingConfidence ?? DEFAULT_CONFIDENCE);
+    this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
+        options.minHandPresenceConfidence ?? DEFAULT_CONFIDENCE);
 
     if (options.cannedGesturesClassifierOptions) {
       // Note that we have to support both JSPB and ProtobufJS and cannot
@@ -281,17 +271,6 @@ export class GestureRecognizer extends VisionTaskRunner {
     }
   }
 
-  /** Sets the default values for the graph. */
-  private initDefaults(): void {
-    this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS);
-    this.handDetectorGraphOptions.setMinDetectionConfidence(
-        DEFAULT_SCORE_THRESHOLD);
-    this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
-        DEFAULT_SCORE_THRESHOLD);
-    this.handLandmarkerGraphOptions.setMinTrackingConfidence(
-        DEFAULT_SCORE_THRESHOLD);
-  }
-
   /** Converts the proto data to a Category[][] structure. */
   private toJsCategories(data: Uint8Array[], populateIndex = true):
       Category[][] {
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 72ca95e66..790486676 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -286,6 +286,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/fist_landmarks.pbtxt?generation=1666999360561864"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_fist_png",
+        sha256 = "4397b3d3f590c88a8de7d21c08d73a0df4a97fd93f92cbd086eef37fd246daaa",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/fist.png?generation=1672952068696274"],
+    )
+
     http_file(
         name = "com_google_mediapipe_general_meta_json",
         sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f",

From 2b9299959cddc5505cb1d28fc50a2f9d46702f12 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 9 Jan 2023 09:14:05 -0800
Subject: [PATCH 275/346] Internal change

PiperOrigin-RevId: 500730237
---
 .../web/vision/object_detector/object_detector_result.d.ts | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
index e9e3843bc..c9c87a1bf 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
@@ -16,6 +16,8 @@
 
 import {Category} from '../../../../tasks/web/components/containers/category';
 
+export {Category};
+
 /** An integer bounding box, axis aligned. */
 export declare interface BoundingBox {
   /** The X coordinate of the top-left corner, in pixels. */

From c6cf598774810fdf45f325a8b5cb083884a13e6d Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 9 Jan 2023 09:52:04 -0800
Subject: [PATCH 276/346] Minor fix for max_queue_size documentation

PiperOrigin-RevId: 500738798
---
 mediapipe/framework/calculator.proto | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto
index 7c5e8b144..eecd033c9 100644
--- a/mediapipe/framework/calculator.proto
+++ b/mediapipe/framework/calculator.proto
@@ -382,7 +382,7 @@ message CalculatorGraphConfig {
   // is empty and no other nodes are running (to prevent possible deadlocks due
   // to an incorrectly specified value). This global parameter is set to 100
   // packets by default to enable pipelining. If any node indicates that it
-  // buffers packets before emitting them, then the max(node_buffer_size,
+  // buffers packets before emitting them, then the max(buffer_size_hint,
   // max_queue_size) is used. Set this parameter to -1 to disable throttling
   // (i.e. the graph will use as much memory as it requires). If not specified,
   // the limit is 100 packets.
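For reference, here is a minimal sketch of a graph config that sets the field documented above. It is illustrative only and not part of any patch in this series; the stream and calculator names are assumptions chosen for the example.

```
# Caps each input-stream queue at 10 packets; once a queue would exceed this,
# upstream nodes are throttled, subject to the buffer_size_hint rule above.
max_queue_size: 10
input_stream: "input_video"
node {
  calculator: "PassThroughCalculator"
  input_stream: "input_video"
  output_stream: "output_video"
}
```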
From 73f4636292b4ee65c36863a664b3dfb9e11b36a5 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 9 Jan 2023 10:34:26 -0800
Subject: [PATCH 277/346] Create README.md files for NPM packages

PiperOrigin-RevId: 500750516
---
 mediapipe/tasks/web/BUILD            |  3 ++
 mediapipe/tasks/web/audio/BUILD      |  2 +
 mediapipe/tasks/web/audio/README.md  | 31 +++++++++++
 mediapipe/tasks/web/text/BUILD       |  2 +
 mediapipe/tasks/web/text/README.md   | 34 ++++++++++++
 mediapipe/tasks/web/vision/BUILD     |  2 +
 mediapipe/tasks/web/vision/README.md | 78 ++++++++++++++++++++++++++++
 7 files changed, 152 insertions(+)
 create mode 100644 mediapipe/tasks/web/audio/README.md
 create mode 100644 mediapipe/tasks/web/text/README.md
 create mode 100644 mediapipe/tasks/web/vision/README.md

diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD
index bc9e84147..02bd70dd0 100644
--- a/mediapipe/tasks/web/BUILD
+++ b/mediapipe/tasks/web/BUILD
@@ -65,6 +65,7 @@ pkg_npm(
         "wasm/audio_wasm_nosimd_internal.js",
         "wasm/audio_wasm_nosimd_internal.wasm",
         ":audio_bundle",
+        "//mediapipe/tasks/web/audio:README.md",
     ],
 )
 
@@ -108,6 +109,7 @@ pkg_npm(
         "wasm/text_wasm_nosimd_internal.js",
         "wasm/text_wasm_nosimd_internal.wasm",
         ":text_bundle",
+        "//mediapipe/tasks/web/text:README.md",
     ],
 )
 
@@ -151,5 +153,6 @@ pkg_npm(
         "wasm/vision_wasm_nosimd_internal.js",
         "wasm/vision_wasm_nosimd_internal.wasm",
         ":vision_bundle",
+        "//mediapipe/tasks/web/vision:README.md",
     ],
 )
diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD
index 9d26f1118..50a611f41 100644
--- a/mediapipe/tasks/web/audio/BUILD
+++ b/mediapipe/tasks/web/audio/BUILD
@@ -4,6 +4,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+exports_files(["README.md"])
+
 mediapipe_ts_library(
     name = "audio_lib",
     srcs = ["index.ts"],
diff --git a/mediapipe/tasks/web/audio/README.md b/mediapipe/tasks/web/audio/README.md
new file mode 100644
index 000000000..834785709
--- /dev/null
+++ b/mediapipe/tasks/web/audio/README.md
@@ -0,0 +1,31 @@
+# MediaPipe Tasks Audio Package
+
+This package contains the audio tasks for MediaPipe.
+
+## Audio Classification
+
+The MediaPipe Audio Classification task performs classification on audio data.
+
+```
+const audio = await FilesetResolver.forAudioTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio@latest/wasm"
+);
+const audioClassifier = await AudioClassifier.createFromModelPath(audio,
+    "https://storage.googleapis.com/mediapipe-tasks/audio_classifier/yamnet_audio_classifier_with_metadata.tflite"
+);
+const classifications = audioClassifier.classify(audioData);
+```
+
+## Audio Embedding
+
+The MediaPipe Audio Embedding task extracts embeddings from audio data.
+
+```
+const audio = await FilesetResolver.forAudioTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio@latest/wasm"
+);
+const audioEmbedder = await AudioEmbedder.createFromModelPath(audio,
+    "model.tflite"
+);
+const embeddings = audioEmbedder.embed(audioData);
+```
diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD
index 32f43d4b6..077b25645 100644
--- a/mediapipe/tasks/web/text/BUILD
+++ b/mediapipe/tasks/web/text/BUILD
@@ -4,6 +4,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+exports_files(["README.md"])
+
 mediapipe_ts_library(
     name = "text_lib",
     srcs = ["index.ts"],
diff --git a/mediapipe/tasks/web/text/README.md b/mediapipe/tasks/web/text/README.md
new file mode 100644
index 000000000..247dc6d30
--- /dev/null
+++ b/mediapipe/tasks/web/text/README.md
@@ -0,0 +1,34 @@
+# MediaPipe Tasks Text Package
+
+This package contains the text tasks for MediaPipe.
+
+## Text Classification
+
+MediaPipe Text Classifier task lets you classify text into a set of defined
+categories, such as positive or negative sentiment.
+
+```
+const text = await FilesetResolver.forTextTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm"
+);
+const textClassifier = await TextClassifier.createFromModelPath(text,
+    "https://storage.googleapis.com/mediapipe-tasks/text_classifier/bert_text_classifier.tflite"
+);
+const classifications = textClassifier.classify(textData);
+```
+
+For more information, refer to the [Text Classification](https://developers.google.com/mediapipe/solutions/text/text_classifier/web_js) documentation.
+
+## Text Embedding
+
+The MediaPipe Text Embedding task extracts embeddings from text data.
+
+```
+const text = await FilesetResolver.forTextTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm"
+);
+const textEmbedder = await TextEmbedder.createFromModelPath(text,
+    "https://storage.googleapis.com/mediapipe-tasks/text_embedder/mobilebert_embedding_with_metadata.tflite"
+);
+const embeddings = textEmbedder.embed(textData);
+```
diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD
index 93493e873..ea022e900 100644
--- a/mediapipe/tasks/web/vision/BUILD
+++ b/mediapipe/tasks/web/vision/BUILD
@@ -4,6 +4,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+exports_files(["README.md"])
+
 mediapipe_ts_library(
     name = "vision_lib",
     srcs = ["index.ts"],
diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md
new file mode 100644
index 000000000..51f43821c
--- /dev/null
+++ b/mediapipe/tasks/web/vision/README.md
@@ -0,0 +1,78 @@
+# MediaPipe Tasks Vision Package
+
+This package contains the vision tasks for MediaPipe.
+
+## Object Detection
+
+The MediaPipe Object Detector task lets you detect the presence and location of
+multiple classes of objects within images or videos.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const objectDetector = await ObjectDetector.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const detections = objectDetector.detect(image);
+```
+
+For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation.
+
+## Image Classification
+
+The MediaPipe Image Classifier task lets you perform classification on images.
+You can use this task to identify what an image represents among a set of
+categories defined at training time.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const imageClassifier = await ImageClassifier.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/image_classifier/efficientnet_lite0_uint8.tflite"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const classifications = imageClassifier.classify(image);
+```
+
+For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.
+
+## Gesture Recognition
+
+The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
+time, and provides the recognized hand gesture results along with the landmarks
+of the detected hands. You can use this task to recognize specific hand gestures
+from a user, and invoke application features that correspond to those gestures.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const recognitions = gestureRecognizer.recognize(image);
+```
+
+## Hand Landmark Detection
+
+The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in
+an image. You can use this task to localize key points of the hands and render
+visual effects over the hands.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const handLandmarker = await HandLandmarker.createFromModelPath(vision,
+    "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const landmarks = handLandmarker.detect(image);
+```
+
+For more information, refer to the [Hand Landmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation.
+

From d40fa6b16d9e14cf0ac7ff30efa45eef588567d5 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 9 Jan 2023 11:02:48 -0800
Subject: [PATCH 278/346] Internal Model Maker change.
PiperOrigin-RevId: 500758488
---
 .../python/core/tasks/classifier.py           |  16 ++-
 .../python/core/utils/model_util.py           |   4 +-
 .../python/vision/image_classifier/BUILD      |  10 --
 .../vision/image_classifier/__init__.py       |   1 -
 .../image_classifier/image_classifier.py      |  96 +++++++++--------
 .../train_image_classifier_lib.py             | 102 ------------------
 6 files changed, 67 insertions(+), 162 deletions(-)
 delete mode 100644 mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py

diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py
index f376edffa..0908dddf5 100644
--- a/mediapipe/model_maker/python/core/tasks/classifier.py
+++ b/mediapipe/model_maker/python/core/tasks/classifier.py
@@ -48,11 +48,12 @@ class Classifier(custom_model.CustomModel):
     self._hparams: hp.BaseHParams = None
     self._history: tf.keras.callbacks.History = None
 
-  # TODO: Integrate this into all Model Maker tasks.
+  # TODO: Integrate this into GestureRecognizer.
   def _train_model(self,
                    train_data: classification_ds.ClassificationDataset,
                    validation_data: classification_ds.ClassificationDataset,
-                   preprocessor: Optional[Callable[..., bool]] = None):
+                   preprocessor: Optional[Callable[..., bool]] = None,
+                   checkpoint_path: Optional[str] = None):
     """Trains the classifier model.
 
     Compiles and fits the tf.keras `_model` and records the `_history`.
@@ -62,6 +63,9 @@
       validation_data: Validation data.
       preprocessor: An optional data preprocessor that can be used when
        generating a tf.data.Dataset.
+      checkpoint_path: An optional directory for the checkpoint file to support
+        continual training. If provided, loads model weights from the latest
+        checkpoint in the directory.
     """
     tf.compat.v1.logging.info('Training the models...')
     if len(train_data) < self._hparams.batch_size:
@@ -88,6 +92,14 @@
         optimizer=self._optimizer,
         loss=self._loss_function,
         metrics=[self._metric_function])
+
+    latest_checkpoint = (
+        tf.train.latest_checkpoint(checkpoint_path)
+        if checkpoint_path else None)
+    if latest_checkpoint:
+      print(f'Resuming from {latest_checkpoint}')
+      self._model.load_weights(latest_checkpoint)
+
     self._history = self._model.fit(
         x=train_dataset,
         epochs=self._hparams.epochs,
diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py
index f10d9390c..db02444df 100644
--- a/mediapipe/model_maker/python/core/utils/model_util.py
+++ b/mediapipe/model_maker/python/core/utils/model_util.py
@@ -42,7 +42,9 @@ def get_default_callbacks(
 
   checkpoint_path = os.path.join(export_dir, 'checkpoint')
   checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      checkpoint_path, save_weights_only=True)
+      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
+      save_weights_only=True,
+      period=5)
   return [summary_callback, checkpoint_callback]
 
 
diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD
index d7c47a359..bd916a92b 100644
--- a/mediapipe/model_maker/python/vision/image_classifier/BUILD
+++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD
@@ -87,15 +87,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "train_image_classifier_lib",
-    srcs = ["train_image_classifier_lib.py"],
-    deps = [
-        ":hyperparameters",
-        "//mediapipe/model_maker/python/core/utils:model_util",
-    ],
-)
-
 py_library(
     name = "image_classifier",
     srcs = ["image_classifier.py"],
@@ -104,7 +95,6 @@ py_library(
         ":image_classifier_options",
         ":model_options",
         ":model_spec",
-        ":train_image_classifier_lib",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
         "//mediapipe/model_maker/python/core/utils:model_util",
diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py
index 0f964ef66..4cde9e7e3 100644
--- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py
+++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py
@@ -35,4 +35,3 @@ del image_classifier
 del image_classifier_options
 del model_options
 del model_spec
-del train_image_classifier_lib  # pylint: disable=undefined-variable
diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py
index df71a8fef..c2181121c 100644
--- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py
+++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py
@@ -28,7 +28,6 @@
 from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
 from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
-from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
 from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 
@@ -57,6 +56,10 @@
         mean_rgb=self._model_spec.mean_rgb,
         stddev_rgb=self._model_spec.stddev_rgb,
         use_augmentation=hparams.do_data_augmentation)
+    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
+    self._loss_function = tf.keras.losses.CategoricalCrossentropy(
+        label_smoothing=self._hparams.label_smoothing)
+    self._metric_function = 'accuracy'
     self._history = None  # Training history returned from `keras_model.fit`.
 
   @classmethod
@@ -66,7 +69,7 @@
       validation_data: classification_ds.ClassificationDataset,
       options: image_classifier_options.ImageClassifierOptions,
   ) -> 'ImageClassifier':
-    """Creates and trains an image classifier.
+    """Creates and trains an ImageClassifier.
 
     Loads data and trains the model based on data for image classification.
     If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
@@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier):
         label_names=train_data.label_names,
         hparams=options.hparams,
         model_options=options.model_options)
-
-    image_classifier._create_model()
-
-    tf.compat.v1.logging.info('Training the models...')
-    image_classifier._train(
-        train_data=train_data, validation_data=validation_data)
-
+    image_classifier._create_and_train_model(train_data, validation_data)
     return image_classifier
 
-  # TODO: Migrate to the shared training library of Model Maker.
-  def _train(self, train_data: classification_ds.ClassificationDataset,
-             validation_data: classification_ds.ClassificationDataset):
-    """Trains the model with input train_data.
-
-    The training results are recorded by a self._history object returned by
-    tf.keras.Model.fit().
+  def _create_and_train_model(
+      self, train_data: classification_ds.ClassificationDataset,
+      validation_data: classification_ds.ClassificationDataset):
+    """Creates and trains the model and optimizer.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
     """
-
-    tf.compat.v1.logging.info('Training the models...')
-    hparams = self._hparams
-    if len(train_data) < hparams.batch_size:
-      raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
-                       'than batch_size (%d). To solve this problem, set '
-                       'the batch_size smaller or increase the size of the '
-                       'train_data.' % (len(train_data), hparams.batch_size))
-
-    train_dataset = train_data.gen_tf_dataset(
-        batch_size=hparams.batch_size,
-        is_training=True,
-        shuffle=self._shuffle,
-        preprocess=self._preprocess)
-    hparams.steps_per_epoch = model_util.get_steps_per_epoch(
-        steps_per_epoch=hparams.steps_per_epoch,
-        batch_size=hparams.batch_size,
+    self._create_model()
+    self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
+        steps_per_epoch=self._hparams.steps_per_epoch,
+        batch_size=self._hparams.batch_size,
         train_data=train_data)
-    train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
-
-    validation_dataset = validation_data.gen_tf_dataset(
-        batch_size=hparams.batch_size,
-        is_training=False,
-        preprocess=self._preprocess)
-
-    # Train the model.
-    self._history = train_image_classifier_lib.train_model(
-        model=self._model,
-        hparams=hparams,
-        train_ds=train_dataset,
-        validation_ds=validation_dataset)
+    self._optimizer = self._create_optimizer()
+    self._train_model(
+        train_data=train_data,
+        validation_data=validation_data,
+        preprocessor=self._preprocess,
+        checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint'))
 
   def _create_model(self):
     """Creates the classifier model from TFHub pretrained models."""
@@ -198,3 +172,33 @@
     model_util.save_tflite(tflite_model_with_metadata, tflite_file)
     with open(metadata_file, 'w') as f:
       f.write(metadata_json)
+
+  def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
+    """Creates an optimizer with learning rate schedule.
+
+    Uses Keras CosineDecay schedule for the learning rate by default.
+
+    Returns:
+      A tf.keras.optimizers.Optimizer for model training.
+    """
+    # Learning rate is linear to batch size.
+    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
+
+    # Get decay steps.
+    total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs
+    default_decay_steps = (
+        self._hparams.decay_samples // self._hparams.batch_size)
+    decay_steps = max(total_training_steps, default_decay_steps)
+
+    learning_rate_fn = tf.keras.experimental.CosineDecay(
+        initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
+    warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch
+    if warmup_steps:
+      learning_rate_fn = model_util.WarmUp(
+          initial_learning_rate=init_lr,
+          decay_schedule_fn=learning_rate_fn,
+          warmup_steps=warmup_steps)
+    optimizer = tf.keras.optimizers.RMSprop(
+        learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
+
+    return optimizer
diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py
deleted file mode 100644
index c5b28cff5..000000000
--- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Library to train model."""
-
-import os
-import tensorflow as tf
-
-from mediapipe.model_maker.python.core.utils import model_util
-from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
-
-
-def _create_optimizer(init_lr: float, decay_steps: int,
-                      warmup_steps: int) -> tf.keras.optimizers.Optimizer:
-  """Creates an optimizer with learning rate schedule.
-
-  Uses Keras CosineDecay schedule for the learning rate by default.
-
-  Args:
-    init_lr: Initial learning rate.
-    decay_steps: Number of steps to decay over.
-    warmup_steps: Number of steps to do warmup for.
-
-  Returns:
-    A tf.keras.optimizers.Optimizer for model training.
-  """
-  learning_rate_fn = tf.keras.experimental.CosineDecay(
-      initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
-  if warmup_steps:
-    learning_rate_fn = model_util.WarmUp(
-        initial_learning_rate=init_lr,
-        decay_schedule_fn=learning_rate_fn,
-        warmup_steps=warmup_steps)
-  optimizer = tf.keras.optimizers.RMSprop(
-      learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
-
-  return optimizer
-
-
-def train_model(model: tf.keras.Model, hparams: hp.HParams,
-                train_ds: tf.data.Dataset,
-                validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
-  """Trains model with the input data and hyperparameters.
-
-  Args:
-    model: Input tf.keras.Model.
-    hparams: Hyperparameters for training image classifier.
-    train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
-    validation_ds: tf.data.Dataset, validation data to be fed in
-      tf.keras.Model.fit().
-
-  Returns:
-    The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
-  """
-
-  # Learning rate is linear to batch size.
-  learning_rate = hparams.learning_rate * hparams.batch_size / 256
-
-  # Get decay steps.
-  # NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
-  total_training_steps = hparams.steps_per_epoch * hparams.epochs
-  default_decay_steps = hparams.decay_samples // hparams.batch_size
-  decay_steps = max(total_training_steps, default_decay_steps)
-
-  warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
-  optimizer = _create_optimizer(
-      init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
-
-  loss = tf.keras.losses.CategoricalCrossentropy(
-      label_smoothing=hparams.label_smoothing)
-  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
-
-  summary_dir = os.path.join(hparams.export_dir, 'summaries')
-  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
-  # Save checkpoint every 5 epochs.
-  checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
-  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
-      save_weights_only=True,
-      period=5)
-
-  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
-  if latest_checkpoint:
-    print(f'Resuming from {latest_checkpoint}')
-    model.load_weights(latest_checkpoint)
-
-  # Train the model.
-  return model.fit(
-      x=train_ds,
-      epochs=hparams.epochs,
-      validation_data=validation_ds,
-      callbacks=[summary_callback, checkpoint_callback])

From 08310231145b3c82e3d72effa49e081960b6be58 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 9 Jan 2023 16:15:52 -0800
Subject: [PATCH 279/346] Use uppercase enum constants for RunningMode

PiperOrigin-RevId: 500833769
---
 .../tasks/web/vision/core/vision_task_options.d.ts |  2 +-
 .../web/vision/core/vision_task_runner.test.ts     | 14 +++++++-------
 .../tasks/web/vision/core/vision_task_runner.ts    |  6 +++---
 .../vision/image_embedder/image_embedder_test.ts   |  2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts
index 76c0177a0..44b1660ff 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts
@@ -21,7 +21,7 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'
  * 1) The image mode for processing single image inputs.
  * 2) The video mode for processing decoded frames of a video.
  */
-export type RunningMode = 'image'|'video';
+export type RunningMode = 'IMAGE'|'VIDEO';
 
 /** The options for configuring a MediaPipe vision task. */
*/ export declare interface VisionTaskOptions extends TaskRunnerOptions { diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index a48381038..4567134b8 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -118,19 +118,19 @@ describe('VisionTaskRunner', () => { }); it('can enable image mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'image'}); + await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: true})); }); it('can clear running mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); // Clear running mode await visionTaskRunner.setOptions( @@ -140,7 +140,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process images with video mode', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(() => { visionTaskRunner.processImageData( IMAGE, /* imageProcessingOptions= */ undefined); @@ -155,7 +155,7 @@ describe('VisionTaskRunner', () => { }).toThrowError(/Task is not initialized with video mode./); // Explicitly set to image mode - await visionTaskRunner.setOptions({runningMode: 'image'}); + await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(() => { visionTaskRunner.processVideoData( IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); @@ -163,7 +163,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); visionTaskRunner.expectNormalizedRect(0.5, 0.5, 1, 1); @@ -172,7 +172,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph with image processing options', async () => { - await visionTaskRunner.setOptions({runningMode: 'video'}); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); visionTaskRunner.expectNormalizedRect(0.3, 0.6, 0.2, 0.4); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 9ed9ffdb2..71cac920c 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -56,7 +56,7 @@ export abstract class VisionTaskRunner extends TaskRunner { override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { const useStreamMode = - !!options.runningMode && options.runningMode !== 'image'; + !!options.runningMode && options.runningMode !== 'IMAGE'; this.baseOptions.setUseStreamMode(useStreamMode); } return super.applyOptions(options); @@ -69,7 +69,7 @@ export abstract class VisionTaskRunner extends TaskRunner { if (!!this.baseOptions?.getUseStreamMode()) { throw new Error( 'Task is not initialized with image mode. 
' + - '\'runningMode\' must be set to \'image\'.'); + '\'runningMode\' must be set to \'IMAGE\'.'); } // Increment the timestamp by 1 millisecond to guarantee that we send @@ -86,7 +86,7 @@ export abstract class VisionTaskRunner extends TaskRunner { if (!this.baseOptions?.getUseStreamMode()) { throw new Error( 'Task is not initialized with video mode. ' + - '\'runningMode\' must be set to \'video\'.'); + '\'runningMode\' must be set to \'VIDEO\'.'); } this.process(imageFrame, imageProcessingOptions, timestamp); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index 01ec751e3..5a8293c44 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -143,7 +143,7 @@ describe('ImageEmbedder', () => { }); it('for video mode', async () => { - await imageEmbedder.setOptions({runningMode: 'video'}); + await imageEmbedder.setOptions({runningMode: 'VIDEO'}); // Invoke the video embedder const embeddingResult = From 704964be33d737c44e6154fde410b363df161e73 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 9 Jan 2023 14:03:42 -0800 Subject: [PATCH 280/346] Fix accidental suppressions of GLSL linker error reporting PiperOrigin-RevId: 500802177 --- mediapipe/gpu/shader_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/gpu/shader_util.cc b/mediapipe/gpu/shader_util.cc index 2132cbda9..5de7e24f5 100644 --- a/mediapipe/gpu/shader_util.cc +++ b/mediapipe/gpu/shader_util.cc @@ -140,7 +140,7 @@ GLint GlhCreateProgram(const GLchar* vert_src, const GLchar* frag_src, glBindAttribLocation(*program, attr_locations[i], attr_names[i]); } - ok = GlhLinkProgram(*program); + ok = GlhLinkProgram(*program, force_log_errors); } if (vert_shader) glDeleteShader(vert_shader); From 76a7c9d5d488eb1c661bd6cb219eba35f7cd07ed Mon Sep 17 00:00:00 2001 From: Liam Miller-Cushon Date: Mon, 9 Jan 2023 14:47:21 -0800 Subject: [PATCH 281/346] Internal change PiperOrigin-RevId: 500813290 --- .../android/solutions/gradle/wrapper/gradle-wrapper.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 41dfb8790..070cb702f 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists From d7ee875356012514d8d5287a360cb8ea391ad0b2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 9 Jan 2023 16:15:52 -0800 Subject: [PATCH 282/346] Fix spacing issue in test name PiperOrigin-RevId: 500833769 --- .../web/vision/gesture_recognizer/gesture_recognizer_test.ts | 2 +- .../tasks/web/vision/hand_landmarker/hand_landmarker_test.ts | 2 +- .../tasks/web/vision/object_detector/object_detector_test.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index 3699033b2..dfc252eb6 
100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -147,7 +147,7 @@ describe('GestureRecognizer', () => { ]); }); - describe('setOptions() ', () => { + describe('setOptions()', () => { interface TestCase { optionPath: [keyof GestureRecognizerOptions, ...string[]]; fieldPath: string[]; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index bce0eac02..0abd1df27 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -129,7 +129,7 @@ describe('HandLandmarker', () => { ]); }); - describe('setOptions() ', () => { + describe('setOptions()', () => { interface TestCase { optionPath: [keyof HandLandmarkerOptions, ...string[]]; fieldPath: string[]; diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 5bfb74ab6..ceb96acb1 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -111,7 +111,7 @@ describe('ObjectDetector', () => { verifyGraph(objectDetector, ['displayNamesLocale', 'en']); }); - describe('setOptions() ', () => { + describe('setOptions()', () => { interface TestCase { optionName: keyof ObjectDetectorOptions; protoName: string; From 6032604f94208bf9649a97a564046984ac538819 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 10 Jan 2023 08:42:07 -0800 Subject: [PATCH 283/346] Hide base task api classes for MediaPipe Tasks Python from API docs PiperOrigin-RevId: 501004802 --- mediapipe/tasks/python/audio/core/base_audio_task_api.py | 3 +-- mediapipe/tasks/python/core/BUILD | 1 + mediapipe/tasks/python/core/task_info.py | 2 ++ mediapipe/tasks/python/text/core/base_text_task_api.py | 3 +-- mediapipe/tasks/python/vision/core/base_vision_task_api.py | 3 +-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index b2197c142..5b08a2b76 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -29,6 +29,7 @@ _RunningMode = running_mode_module.AudioTaskRunningMode _Timestamp = timestamp_module.Timestamp +@doc_controls.do_not_generate_docs class BaseAudioTaskApi(object): """The base class of the user-facing mediapipe audio task api classes.""" @@ -133,12 +134,10 @@ class BaseAudioTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Return `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe audio task instance on exit of the context manager. 
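The hunk above moves `@doc_controls.do_not_generate_docs` from the individual dunder methods up to the class declaration, so the whole base class is hidden from the generated API reference in one place. A minimal sketch of the resulting pattern (assuming the `tensorflow_docs` package, which MediaPipe pulls in through its optional dependencies):

    from tensorflow_docs.api_generator import doc_controls

    @doc_controls.do_not_generate_docs
    class BaseAudioTaskApi(object):
      """The class-level decorator hides the class and all of its methods."""

      def __enter__(self):
        # No per-method decorator is needed; the class-level one covers this.
        return self
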
diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index f14d59b99..6098fb5f5 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -43,6 +43,7 @@ py_library( name = "task_info", srcs = ["task_info.py"], deps = [ + ":optional_dependencies", "//mediapipe/calculators/core:flow_limiter_calculator_py_pb2", "//mediapipe/framework:calculator_options_py_pb2", "//mediapipe/framework:calculator_py_pb2", diff --git a/mediapipe/tasks/python/core/task_info.py b/mediapipe/tasks/python/core/task_info.py index 31605480f..6ea2cee7b 100644 --- a/mediapipe/tasks/python/core/task_info.py +++ b/mediapipe/tasks/python/core/task_info.py @@ -20,8 +20,10 @@ from typing import Any, List from mediapipe.calculators.core import flow_limiter_calculator_pb2 from mediapipe.framework import calculator_options_pb2 from mediapipe.framework import calculator_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +@doc_controls.do_not_generate_docs @dataclasses.dataclass class TaskInfo: """Specifications of a MediaPipe task graph. diff --git a/mediapipe/tasks/python/text/core/base_text_task_api.py b/mediapipe/tasks/python/text/core/base_text_task_api.py index b22bfff00..1d6311561 100644 --- a/mediapipe/tasks/python/text/core/base_text_task_api.py +++ b/mediapipe/tasks/python/text/core/base_text_task_api.py @@ -20,6 +20,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls _TaskRunner = task_runner.TaskRunner +@doc_controls.do_not_generate_docs class BaseTextTaskApi(object): """The base class of the user-facing mediapipe text task api classes.""" @@ -40,12 +41,10 @@ class BaseTextTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Returns `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe text task instance on exit of the context manager. diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index 016170398..0c8262d4b 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -31,6 +31,7 @@ _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +@doc_controls.do_not_generate_docs class BaseVisionTaskApi(object): """The base class of the user-facing mediapipe vision task api classes.""" @@ -178,12 +179,10 @@ class BaseVisionTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Return `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe vision task instance on exit of the context manager. From 25abd122b338de4598edc72987bd91a13104c84d Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 10 Jan 2023 09:44:04 -0800 Subject: [PATCH 284/346] Support AudioRecord in MediaPipe audio tasks in Java. 
PiperOrigin-RevId: 501019327 --- .../tasks/audio/core/BaseAudioTaskApi.java | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 2782f8d36..7abde72d5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -14,6 +14,9 @@ package com.google.mediapipe.tasks.audio.core; +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.media.MediaRecorder; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.tasks.components.containers.AudioData; @@ -149,4 +152,71 @@ public class BaseAudioTaskApi implements AutoCloseable { public void close() { runner.close(); } + + /** + * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned + * AudioRecord instance is initialized and client needs to call {@link + * android.media.AudioRecord#startRecording} method to start recording. + * + *
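+   * <p>A minimal usage sketch (hypothetical caller code: any task class that extends
+   * BaseAudioTaskApi, e.g. AudioClassifier, inherits this helper; the buffer size below is
+   * arbitrary):
+   *
+   * <pre>{@code
+   * AudioRecord record = AudioClassifier.createAudioRecord(1, 16000);
+   * record.startRecording();
+   * float[] buffer = new float[1600];
+   * int read = record.read(buffer, 0, buffer.length, AudioRecord.READ_BLOCKING);
+   * }</pre>
+   *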
+   * <p>Note that MediaPipe Audio tasks will up/down sample automatically to fit the sample rate
+   * required by the model. The default sample rate of the MediaPipe pretrained audio model, Yamnet,
+   * is 16kHz.
+   *
+   * @param numChannels the number of audio channels.
+   * @param sampleRate the audio sample rate.
+   * @return an {@link android.media.AudioRecord} instance in {@link
+   *     android.media.AudioRecord#STATE_INITIALIZED}
+   * @throws IllegalArgumentException if the model required channel count is unsupported
+   * @throws IllegalStateException if AudioRecord instance failed to initialize
+   */
+  public static AudioRecord createAudioRecord(int numChannels, int sampleRate) {
+    int channelConfig = 0;
+    switch (numChannels) {
+      case 1:
+        channelConfig = AudioFormat.CHANNEL_IN_MONO;
+        break;
+      case 2:
+        channelConfig = AudioFormat.CHANNEL_IN_STEREO;
+        break;
+      default:
+        throw new IllegalArgumentException(
+            "createAudioRecord method only supports 1 or 2 audio channels.");
+    }
+
+    int bufferSizeInBytes =
+        AudioRecord.getMinBufferSize(sampleRate, channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
+    if (bufferSizeInBytes == AudioRecord.ERROR
+        || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
+      throw new IllegalStateException(
+          String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
+    }
+    AudioRecord audioRecord =
+        new AudioRecord(
+            // VOICE_RECOGNITION is preferred over other sources, including MIC, UNPROCESSED, and CAMCORDER.
+            MediaRecorder.AudioSource.VOICE_RECOGNITION,
+            sampleRate,
+            channelConfig,
+            AudioFormat.ENCODING_PCM_FLOAT,
+            bufferSizeInBytes);
+    if (audioRecord.getState() != AudioRecord.STATE_INITIALIZED) {
+      throw new IllegalStateException("AudioRecord failed to initialize.");
+    }
+    return audioRecord;
+  }
+
+  /**
+   * Creates an {@link android.media.AudioRecord} instance to record an audio stream that has a
+   * mono channel at sample rate 16kHz, the sample rate required for models like Yamnet.
+   * The returned AudioRecord instance is initialized and client needs to call {@link
+   * android.media.AudioRecord#startRecording} method to start recording.
+   *
+   * @return an {@link android.media.AudioRecord} instance in {@link
+   *     android.media.AudioRecord#STATE_INITIALIZED}
+   * @throws IllegalArgumentException if the model required channel count is unsupported
+   * @throws IllegalStateException if AudioRecord instance failed to initialize
+   */
+  public static AudioRecord createAudioRecord() {
+    // TODO: Support creating AudioRecord based on the model specifications.
+    return createAudioRecord(1, 16000);
+  }
 }

From 54268594dd8d6aa75222a408cc03a049b82be467 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Tue, 10 Jan 2023 17:35:57 -0800
Subject: [PATCH 285/346] Internal change.
PiperOrigin-RevId: 501136760 --- .../formats/tensor/cpu_buffer_converters.cc | 240 +++++++++++++++ .../tensor/cpu_buffer_converters_test.cc | 282 ++++++++++++++++++ 2 files changed, 522 insertions(+) create mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters.cc create mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc b/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc new file mode 100644 index 000000000..e4e705be5 --- /dev/null +++ b/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc @@ -0,0 +1,240 @@ +#include +#include +#include + +#include "mediapipe/framework/formats/tensor/backend.h" +#include "mediapipe/framework/formats/tensor/tensor2.h" +#include "mediapipe/framework/formats/tensor/views/buffer.h" +#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" +#include "third_party/FP16/include/fp16.h" + +namespace mediapipe { +namespace { + +template +auto ConverterCheckFunction() { + return + [](const Tensor2& tensor, uint64_t source_descriptor_type_id, + const Tensor2::ViewDescriptor& source_base_descriptor, + uint64_t destination_descriptor_type_id, + const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { + if (source_descriptor_type_id != TensorCpuView::kId || + destination_descriptor_type_id != TensorCpuView::kId) + return false; + auto source_descriptor = + static_cast(source_base_descriptor); + auto destination_descriptor = + static_cast( + destination_base_descriptor); + return source_descriptor.buffer.format == + TensorTypeToFormat::value && + destination_descriptor.buffer.format == + TensorTypeToFormat::value; + }; +} + +template +auto ConvertFunction() { + return [](const Tensor2& tensor, const Tensor2::View& source_base_view, + const Tensor2::View& destination_base_view) -> bool { + auto source = source_base_view.DownCast(); + auto destination = destination_base_view.DownCast(); + if (source->descriptor().buffer.format == + destination->descriptor().buffer.format) { + std::memcpy( + destination->data(), source->data(), + TensorBufferSize(destination->descriptor().buffer, tensor.shape())); + } else { + auto source_pointer = source->data(); + auto destination_pointer = destination->data(); + for (int i = 0; i < tensor.shape().NumElements(); i++) { + *destination_pointer++ = + GpuLikeTypeCast(*source_pointer++); + } + } + return true; + }; +} + +#define REGISTER_CONVERTER(SourceType, DestinationType) \ + TENSOR_REGISTER_CONVERTER( \ + {ConverterCheckFunction(), \ + ConvertFunction()}); + +REGISTER_CONVERTER(float, Float16); +REGISTER_CONVERTER(float, int8_t); +REGISTER_CONVERTER(float, uint8_t); +REGISTER_CONVERTER(float, int16_t); +REGISTER_CONVERTER(float, uint16_t); +REGISTER_CONVERTER(float, int32_t); +REGISTER_CONVERTER(float, uint32_t); + +REGISTER_CONVERTER(Float16, float); +REGISTER_CONVERTER(Float16, int8_t); +REGISTER_CONVERTER(Float16, uint8_t); +REGISTER_CONVERTER(Float16, int16_t); +REGISTER_CONVERTER(Float16, uint16_t); +REGISTER_CONVERTER(Float16, int32_t); +REGISTER_CONVERTER(Float16, uint32_t); + +REGISTER_CONVERTER(int8_t, float); +REGISTER_CONVERTER(int8_t, Float16); +REGISTER_CONVERTER(int8_t, uint8_t); +REGISTER_CONVERTER(int8_t, int16_t); +REGISTER_CONVERTER(int8_t, uint16_t); +REGISTER_CONVERTER(int8_t, int32_t); +REGISTER_CONVERTER(int8_t, uint32_t); + +REGISTER_CONVERTER(uint8_t, float); +REGISTER_CONVERTER(uint8_t, Float16); +REGISTER_CONVERTER(uint8_t, int8_t); +REGISTER_CONVERTER(uint8_t, 
int16_t); +REGISTER_CONVERTER(uint8_t, uint16_t); +REGISTER_CONVERTER(uint8_t, int32_t); +REGISTER_CONVERTER(uint8_t, uint32_t); + +REGISTER_CONVERTER(int16_t, float); +REGISTER_CONVERTER(int16_t, Float16); +REGISTER_CONVERTER(int16_t, int8_t); +REGISTER_CONVERTER(int16_t, uint8_t); +REGISTER_CONVERTER(int16_t, uint16_t); +REGISTER_CONVERTER(int16_t, uint32_t); +REGISTER_CONVERTER(int16_t, uint32_t); + +REGISTER_CONVERTER(uint16_t, float); +REGISTER_CONVERTER(uint16_t, Float16); +REGISTER_CONVERTER(uint16_t, int8_t); +REGISTER_CONVERTER(uint16_t, uint8_t); +REGISTER_CONVERTER(uint16_t, int16_t); +REGISTER_CONVERTER(uint16_t, int32_t); +REGISTER_CONVERTER(uint16_t, uint32_t); + +REGISTER_CONVERTER(int32_t, float); +REGISTER_CONVERTER(int32_t, Float16); +REGISTER_CONVERTER(int32_t, int8_t); +REGISTER_CONVERTER(int32_t, uint8_t); +REGISTER_CONVERTER(int32_t, int16_t); +REGISTER_CONVERTER(int32_t, uint16_t); +REGISTER_CONVERTER(int32_t, uint32_t); + +REGISTER_CONVERTER(uint32_t, float); +REGISTER_CONVERTER(uint32_t, Float16); +REGISTER_CONVERTER(uint32_t, int8_t); +REGISTER_CONVERTER(uint32_t, uint8_t); +REGISTER_CONVERTER(uint32_t, int16_t); +REGISTER_CONVERTER(uint32_t, uint16_t); +REGISTER_CONVERTER(uint32_t, int32_t); + +template +auto DequantizationCheckFunction() { + return + [](const Tensor2& tensor, uint64_t source_descriptor_type_id, + const Tensor2::ViewDescriptor& source_base_descriptor, + uint64_t destination_descriptor_type_id, + const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { + if (source_descriptor_type_id != TensorCpuView::kId || + destination_descriptor_type_id != TensorCpuView::kId) + return false; + auto source_descriptor = + static_cast(source_base_descriptor); + auto destination_descriptor = + static_cast( + destination_base_descriptor); + return source_descriptor.buffer.format == + TensorBufferDescriptor::Format::kQuantizedInt8 && + destination_descriptor.buffer.format == + TensorTypeToFormat::value; + }; +} + +template +auto DequantizationConvertFunction() { + return [](const Tensor2& tensor, const Tensor2::View& source_base_view, + const Tensor2::View& destination_base_view) -> bool { + auto source = source_base_view.DownCast(); + auto destination = destination_base_view.DownCast(); + auto source_pointer = source->data(); + auto destination_pointer = destination->data(); + int zero_point = + source->descriptor().buffer.quantization_parameters.zero_point; + float scale = source->descriptor().buffer.quantization_parameters.scale; + for (int i = 0; i < tensor.shape().NumElements(); i++) { + *destination_pointer++ = static_cast( + (*source_pointer++ - zero_point) * scale); + } + return true; + }; +} + +#define REGISTER_DEQUANTIZATION_CONVERTER(DestinationType) \ + TENSOR_REGISTER_CONVERTER( \ + {DequantizationCheckFunction(), \ + DequantizationConvertFunction()}); + +REGISTER_DEQUANTIZATION_CONVERTER(float); +REGISTER_DEQUANTIZATION_CONVERTER(Float16); +REGISTER_DEQUANTIZATION_CONVERTER(int8_t); +REGISTER_DEQUANTIZATION_CONVERTER(uint8_t); +REGISTER_DEQUANTIZATION_CONVERTER(int16_t); +REGISTER_DEQUANTIZATION_CONVERTER(uint16_t); +REGISTER_DEQUANTIZATION_CONVERTER(int32_t); +REGISTER_DEQUANTIZATION_CONVERTER(uint32_t); + +template +auto QuantizationCheckFunction() { + return + [](const Tensor2& tensor, uint64_t source_descriptor_type_id, + const Tensor2::ViewDescriptor& source_base_descriptor, + uint64_t destination_descriptor_type_id, + const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { + if (source_descriptor_type_id != 
TensorCpuView::kId || + destination_descriptor_type_id != TensorCpuView::kId) + return false; + auto source_descriptor = + static_cast(source_base_descriptor); + auto destination_descriptor = + static_cast( + destination_base_descriptor); + bool same = source_descriptor.buffer.format == + TensorTypeToFormat::value && + destination_descriptor.buffer.format == + TensorBufferDescriptor::Format::kQuantizedInt8; + return same; + }; +} + +template +auto QuantizationConvertFunction() { + return [](const Tensor2& tensor, const Tensor2::View& source_base_view, + const Tensor2::View& destination_base_view) -> bool { + auto source = source_base_view.DownCast(); + auto destination = destination_base_view.DownCast(); + auto source_pointer = source->data(); + auto destination_pointer = destination->data(); + int zero_point = + destination->descriptor().buffer.quantization_parameters.zero_point; + float scale = + destination->descriptor().buffer.quantization_parameters.scale; + for (int i = 0; i < tensor.shape().NumElements(); i++) { + *destination_pointer++ = + static_cast(*source_pointer++ / scale + zero_point); + } + return true; + }; +} + +#define REGISTER_QUANTIZATION_CONVERTER(SourceType) \ + TENSOR_REGISTER_CONVERTER({QuantizationCheckFunction(), \ + QuantizationConvertFunction()}); + +REGISTER_QUANTIZATION_CONVERTER(float); +REGISTER_QUANTIZATION_CONVERTER(Float16); +REGISTER_QUANTIZATION_CONVERTER(int8_t); +REGISTER_QUANTIZATION_CONVERTER(uint8_t); +REGISTER_QUANTIZATION_CONVERTER(int16_t); +REGISTER_QUANTIZATION_CONVERTER(uint16_t); +REGISTER_QUANTIZATION_CONVERTER(int32_t); +REGISTER_QUANTIZATION_CONVERTER(uint32_t); + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc b/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc new file mode 100644 index 000000000..3619ad531 --- /dev/null +++ b/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc @@ -0,0 +1,282 @@ +#include + +#include "mediapipe/framework/formats/tensor/tensor2.h" +#include "mediapipe/framework/formats/tensor/views/buffer.h" +#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +MATCHER_P(NearWithPrecision, precision, "") { + return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; +} +MATCHER_P(IntegerEqual, precision, "") { + return std::get<0>(arg) == std::get<1>(arg); +} + +namespace mediapipe { + +TEST(TensorCpuViewTest, TestWrite32ThenRead16) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 1234.0f; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat16}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 1234.0f); + } +} + +TEST(TensorCpuViewTest, TestWrite16ThenRead32) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat16}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 1234.0f; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + 
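+    // The converter applies GPU-style range normalization between integer
+    // formats, clamping negative signed values at the unsigned minimum, so
+    // the stored int8 value of -123 is expected to read back as 0.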
ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 1234.0f); + } +} + +TEST(TensorCpuViewTest, TestWriteFloat32ThenReadInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 0.121569f; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ( + *view->data(), + static_cast(0.121569f * std::numeric_limits::max())); + } +} + +TEST(TensorCpuViewTest, TestWriteInt8ThenReadFloat32) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = + static_cast(0.123f * std::numeric_limits::max()); + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_NEAR(*view->data(), 0.123f, + 1.0f / std::numeric_limits::max()); + } +} + +TEST(TensorCpuViewTest, TestWriteUInt8ThenReadUInt16) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 123; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kUInt16}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), uint16_t{123} << 8); + } +} + +TEST(TensorCpuViewTest, TestWriteUInt16ThenReadUInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kUInt16}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = uint16_t{123} << 8; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 123); + } +} + +TEST(TensorCpuViewTest, TestWriteNegativeInt8ThenReadUInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = -123; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + EXPECT_EQ(*view->data(), 0); + } +} + +TEST(TensorCpuViewTest, TestWritePositiveInt8ThenReadUInt8) { + Tensor2 tensor{Tensor2::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + ASSERT_NE(view->data(), nullptr); + *view->data() = 123; + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); + ASSERT_NE(view->data(), nullptr); + 
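+    // GpuLikeTypeCast's range normalization maps int8 [0, 127] onto uint8
+    // [0, 255], so the stored value of 123 is expected to read back as 246,
+    // i.e. 123 * 2.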
EXPECT_EQ(*view->data(), 123 * 2); + } +} + +TEST(TensorCpuViewTest, TestDequantization) { + constexpr int num_elements = 20; + // Gives quantization values in range [-100, 90]. + constexpr int zero_point = -100; + constexpr float scale = 2.0f; + Tensor2 tensor{Tensor2::Shape({num_elements})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = { + .format = TensorBufferDescriptor::Format::kQuantizedInt8, + .quantization_parameters = {.scale = scale, + .zero_point = zero_point}}})); + ASSERT_NE(view->data(), nullptr); + auto data = view->data(); + for (int i = 0; i < num_elements; ++i) { + // Add some bias (+1) to make round-up take place. + data[i] = (i * 20 + 1) / scale + zero_point; + } + } + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + std::vector reference(num_elements); + for (int i = 0; i < num_elements; ++i) { + reference[i] = i * 20.0f + 1.0f; + } + EXPECT_THAT(absl::Span(view->data(), num_elements), + testing::Pointwise(NearWithPrecision(1.001), reference)); + } +} + +TEST(TensorCpuViewTest, TestQuantization) { + constexpr int num_elements = 20; + // Gives quantization values in range [-100, 90]. + constexpr int zero_point = -100; + constexpr float scale = 2.0f; + Tensor2 tensor{Tensor2::Shape({num_elements})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + ASSERT_NE(view->data(), nullptr); + auto data = view->data(); + for (int i = 0; i < num_elements; ++i) { + // Add some bias (+1) to make round-up take place. + data[i] = i * 20 + 1; + } + } + { + TensorCpuViewDescriptor d{ + .buffer = {.format = TensorBufferDescriptor::Format::kQuantizedInt8, + .quantization_parameters = {.scale = scale, + .zero_point = zero_point}}}; + MP_ASSERT_OK_AND_ASSIGN( + auto view, tensor.GetView(d)); + ASSERT_NE(view->data(), nullptr); + std::vector reference(num_elements); + for (int i = 0; i < num_elements; ++i) { + reference[i] = (i * 20 + 1) / scale + zero_point; + } + EXPECT_THAT(absl::Span(view->data(), num_elements), + testing::Pointwise(IntegerEqual(0), reference)); + } +} + +} // namespace mediapipe From ed6abbbe43df13ddb1145ee94cec681dbf9d6473 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 16:21:28 +0530 Subject: [PATCH 286/346] Added iOS text classifier options --- .../tasks/ios/text/text_classifier/BUILD | 27 ++++++++ .../sources/MPPTextClassifierOptions.h | 62 +++++++++++++++++++ .../sources/MPPTextClassifierOptions.m | 40 ++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/BUILD create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD new file mode 100644 index 000000000..dff39baab --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -0,0 +1,27 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptions", + srcs = ["sources/MPPTextClassifierOptions.m"], + hdrs = ["sources/MPPTextClassifierOptions.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskOptions", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h new file mode 100644 index 000000000..d43d801d4 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -0,0 +1,62 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options for setting up a `MPPTextClassifierOptions`. + */ +NS_SWIFT_NAME(TextClassifierOptions) +@interface MPPTextClassifierOptions : MPPTaskOptions + +/** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ +@property(nonatomic, copy) NSString *displayNamesLocale; + +/** + * The maximum number of top-scored classification results to return. If < 0, + * all available results will be returned. If 0, an invalid argument error is + * returned. + */ +@property(nonatomic) NSInteger maxResults; + +/** + * Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. + */ +@property(nonatomic) float scoreThreshold; + +/** + * The allowlist of category names. If non-empty, detection results whose + * category name is not in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryDenylist. + */ +@property(nonatomic, copy) NSArray *categoryAllowlist; + +/** + * The denylist of category names. If non-empty, detection results whose + * category name is in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryAllowlist. + */ +@property(nonatomic, copy) NSArray *categoryDenylist; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m new file mode 100644 index 000000000..2d5c17cda --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -0,0 +1,40 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +@implementation MPPTextClassifierOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _maxResults = -1; + _scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTextClassifierOptions *textClassifierOptions = [super copyWithZone:zone]; + + textClassifierOptions.scoreThreshold = self.scoreThreshold; + textClassifierOptions.maxResults = self.maxResults; + textClassifierOptions.categoryDenylist = self.categoryDenylist; + textClassifierOptions.categoryAllowlist = self.categoryAllowlist; + textClassifierOptions.displayNamesLocale = self.displayNamesLocale; + + return textClassifierOptions; +} + +@end From 1161ebce9d720e544d3cf740b5e6a5aa446979ed Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 16:22:09 +0530 Subject: [PATCH 287/346] Added iOS text classifier result --- .../tasks/ios/text/text_classifier/BUILD | 10 +++++ .../sources/MPPTextClassifierResult.h | 44 +++++++++++++++++++ .../sources/MPPTextClassifierResult.m | 28 ++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index dff39baab..59ef601bf 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -25,3 +25,13 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifierResult", + srcs = ["sources/MPPTextClassifierResult.m"], + hdrs = ["sources/MPPTextClassifierResult.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskResult", + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h new file mode 100644 index 000000000..63bb92352 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Represents the classification results generated by `MPPTextClassifier`. **/ +NS_SWIFT_NAME(TextClassifierResult) +@interface MPPTextClassifierResult : MPPTaskResult + +/** The `MPPClassificationResult` instance containing one set of results per classifier head. **/ +@property(nonatomic, readonly) MPPClassificationResult *classificationResult; + +/** + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and + * timestamp (in milliseconds). + * + * @param classificationResult The `MPPClassificationResult` instance containing one set of results + * per classifier head. + * @param timestampMs The timestamp for this result. + * + * @return An instance of `MPPTextClassifierResult` initialized with the given + * `MPPClassificationResult` and timestamp (in milliseconds). + */ +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampMs:(NSInteger)timestampMs; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m new file mode 100644 index 000000000..4d5c1104a --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +@implementation MPPTextClassifierResult + +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; + if (self) { + _classificationResult = classificationResult; + } + return self; +} + +@end From 54161cc1abaa11701bc2a51d8ef331db55db0b19 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:22:02 +0530 Subject: [PATCH 288/346] Added iOS text classifier options helpers --- .../ios/text/text_classifier/utils/BUILD | 31 ++++++++++ .../MPPTextClassifierOptions+Helpers.h | 26 +++++++++ .../MPPTextClassifierOptions+Helpers.mm | 56 +++++++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/BUILD create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD new file mode 100644 index 000000000..9b01c763e --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptionsHelpers", + srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h new file mode 100644 index 000000000..1e52e5c87 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm new file mode 100644 index 000000000..c370f11ef --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -0,0 +1,56 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using TextClassifierGraphOptionsProto = + ::mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions; +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} // namespace + +@implementation MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + TextClassifierGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(TextClassifierGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + ClassifierOptionsProto *classifierOptionsProto = graphOptions->mutable_classifier_options(); + classifierOptionsProto->Clear(); + + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + for (NSString *category in self.categoryAllowlist) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.categoryDenylist) { + classifierOptionsProto->add_category_denylist(category.cppString); + } + +} + +@end From a0220de2338e4fbc308c95d9fef91383dd817ca4 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:22:20 +0530 Subject: [PATCH 289/346] Added iOS text classifier result helpers --- .../ios/text/text_classifier/utils/BUILD | 10 +++++ .../sources/MPPTextClassifierResult+Helpers.h | 28 ++++++++++++ .../MPPTextClassifierResult+Helpers.mm | 43 +++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 9b01c763e..299050b32 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -29,3 +29,13 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifierResultHelpers", + srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/framework:packet", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h new file mode 100644 index 000000000..f1b728b0a --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
+
+#include "mediapipe/framework/packet.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface MPPTextClassifierResult (Helpers)
+
++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:
+    (const mediapipe::Packet &)packet;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm
new file mode 100644
index 000000000..62e0d8cb1
--- /dev/null
+++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm
@@ -0,0 +1,41 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
+
+#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
+
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+
+static const int kMicroSecondsPerMilliSecond = 1000;
+
+namespace {
+using ClassificationResultProto =
+    ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::Packet;
+}  // namespace
+
+@implementation MPPTextClassifierResult (Helpers)
+
++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet {
+  MPPClassificationResult *classificationResult = [MPPClassificationResult
+      classificationResultWithProto:packet.Get<ClassificationResultProto>()];
+
+  return [[MPPTextClassifierResult alloc]
+      initWithClassificationResult:classificationResult
+                       timestampMs:(NSInteger)(packet.Timestamp().Value() /
+                                               kMicroSecondsPerMilliSecond)];
+}
+
+@end

From b1ded2f700a424a2c6782a4f571bcd70e554fc6b Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Wed, 11 Jan 2023 20:22:33 +0530
Subject: [PATCH 290/346] Added iOS text classifier

---
 mediapipe/tasks/ios/text/text_classifier/BUILD |  21 ++++
 .../sources/MPPTextClassifier.h                | 103 ++++++++++++++++++
 .../sources/MPPTextClassifier.mm               |  98 +++++++++++++++++
 3 files changed, 222 insertions(+)
 create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h
 create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm

diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD
index 59ef601bf..e5242f50d 100644
--- a/mediapipe/tasks/ios/text/text_classifier/BUILD
+++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -35,3 +35,24 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifier", + srcs = ["sources/MPPTextClassifier.mm"], + hdrs = ["sources/MPPTextClassifier.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ":MPPTextClassifierOptions", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h new file mode 100644 index 000000000..10bccad3d --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -0,0 +1,103 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs classification on text. + * + * This API expects a TFLite model with (optional) [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory + * (described below) input tensors, output tensor, and the optional (but recommended) label + * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor. + * + * Metadata is required for models with int32 input tensors because it contains the input + * process unit for the model's Tokenizer. No metadata is required for models with string + * input tensors. + * + * Input tensors + * - Three input tensors `kTfLiteInt32` of shape `[batch_size xbert_max_seq_len]` + * representing the input ids, mask ids, and segment ids. This input signature requires + * a Bert Tokenizer process unit in the model metadata. + * - Or one input tensor `kTfLiteInt32` of shape `[batch_size xmax_seq_len]` representing + * the input ids. This input signature requires a Regex Tokenizer process unit in the + * model metadata. + * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing + * the input string. + * + * At least one output tensor (`kTfLiteFloat32/kBool`) with: + * - `N` classes and shape `[1 x N]` + * - optional (but recommended) label map(s) as AssociatedFiles with type + * TENSOR_AXIS_LABELS, + * containing one label per line. 
The first such AssociatedFile (if any) is used to fill + * the `categoryName` field of the results. The `displayName` field is filled from the + * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the + * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If + * none of these are available, only the `index` field of the results will be filled. + */ +NS_SWIFT_NAME(TextClassifier) +@interface MPPTextClassifier : NSObject + +/** + * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `MPPTextClassifierOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the + * device. + * @param error An optional error parameter populated when there is an error in initializing + * the text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an + * error in initializing the text classifier. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. + * + * @param options The options of type `MPPTextClassifierOptions` to use for configuring the + * `MPPTextClassifier`. + * @param error An optional error parameter populated when there is an error in initializing + * the text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an + * error in initializing the text classifier. + */ +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs classification on the input text. + * + * @param text The `NSString` on which classification is to be performed. + * @param error An optional error parameter populated when there is an error in performing + * classification on the input text. + * + * @return A `MPPTextClassifierResult` object that contains a list of text classifications. + */ +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm new file mode 100644 index 000000000..aed05ec37 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -0,0 +1,98 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
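+
+// Implementation note: the classifier below wires a single TEXT:text_in input
+// stream and a single CLASSIFICATIONS:classifications_out output stream into
+// the TextClassifierGraph named by kTaskGraphName.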
+ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +@interface MPPTextClassifier () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPTextClassifier + +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + + self = [super init]; + + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + + if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + return nil; + } + + return [MPPTextClassifierResult + textClassifierResultWithClassificationsPacket:statusOrOutputPacketMap.value() + [kClassificationsStreamName.cppString]]; +} + +@end From fe05a8d201a12011dc7cc82eaf4dc0f4fad42b20 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:24:17 +0530 Subject: [PATCH 291/346] Reformatted code --- .../sources/MPPTextClassifier.h | 20 +++++++++---------- .../sources/MPPTextClassifier.mm | 6 +++--- .../sources/MPPTextClassifierResult.h | 2 +- .../MPPTextClassifierOptions+Helpers.mm | 1 - 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 
10bccad3d..48498edca 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -22,19 +22,19 @@ NS_ASSUME_NONNULL_BEGIN /** * @brief Performs classification on text. - * + * * This API expects a TFLite model with (optional) [TFLite Model * Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory - * (described below) input tensors, output tensor, and the optional (but recommended) label + * (described below) input tensors, output tensor, and the optional (but recommended) label * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor. * - * Metadata is required for models with int32 input tensors because it contains the input - * process unit for the model's Tokenizer. No metadata is required for models with string + * Metadata is required for models with int32 input tensors because it contains the input + * process unit for the model's Tokenizer. No metadata is required for models with string * input tensors. * * Input tensors * - Three input tensors `kTfLiteInt32` of shape `[batch_size xbert_max_seq_len]` - * representing the input ids, mask ids, and segment ids. This input signature requires + * representing the input ids, mask ids, and segment ids. This input signature requires * a Bert Tokenizer process unit in the model metadata. * - Or one input tensor `kTfLiteInt32` of shape `[batch_size xmax_seq_len]` representing * the input ids. This input signature requires a Regex Tokenizer process unit in the @@ -44,12 +44,12 @@ NS_ASSUME_NONNULL_BEGIN * * At least one output tensor (`kTfLiteFloat32/kBool`) with: * - `N` classes and shape `[1 x N]` - * - optional (but recommended) label map(s) as AssociatedFiles with type + * - optional (but recommended) label map(s) as AssociatedFiles with type * TENSOR_AXIS_LABELS, - * containing one label per line. The first such AssociatedFile (if any) is used to fill - * the `categoryName` field of the results. The `displayName` field is filled from the - * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the - * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If + * containing one label per line. The first such AssociatedFile (if any) is used to fill + * the `categoryName` field of the results. The `displayName` field is filled from the + * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the + * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If * none of these are available, only the `index` field of the results will be filled. 
*/ NS_SWIFT_NAME(TextClassifier) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index aed05ec37..59b5423bb 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -62,11 +62,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T _textTaskRunner = [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; - + if (!_textTaskRunner) { return nil; - } - + } + self = [super init]; return self; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h index 63bb92352..6744a8e16 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -26,7 +26,7 @@ NS_SWIFT_NAME(TextClassifierResult) @property(nonatomic, readonly) MPPClassificationResult *classificationResult; /** - * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and * timestamp (in milliseconds). * * @param classificationResult The `MPPClassificationResult` instance containing one set of results diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm index c370f11ef..de64d970c 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -50,7 +50,6 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto for (NSString *category in self.categoryDenylist) { classifierOptionsProto->add_category_denylist(category.cppString); } - } @end From c7e36f87207731c02c4d1b72399491c2c6d73f24 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 11 Jan 2023 20:31:46 +0530 Subject: [PATCH 292/346] Re-ordered dependencies in build file --- mediapipe/tasks/ios/text/text_classifier/BUILD | 16 ++++++++-------- .../tasks/ios/text/text_classifier/utils/BUILD | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index e5242f50d..a6315840b 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -30,8 +30,8 @@ objc_library( srcs = ["sources/MPPTextClassifierResult.m"], hdrs = ["sources/MPPTextClassifierResult.h"], deps = [ - "//mediapipe/tasks/ios/core:MPPTaskResult", "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) @@ -44,15 +44,15 @@ objc_library( "-std=c++17", ], deps = [ + ":MPPTextClassifierOptions", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", - "//mediapipe/tasks/ios/core:MPPTaskOptions", - "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", - "//mediapipe/tasks/ios/core:MPPTextPacketCreator", - "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", - 
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - ":MPPTextClassifierOptions", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 299050b32..23627391c 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -21,11 +21,11 @@ objc_library( srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"], hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"], deps = [ - "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", - "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", - "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", ], ) @@ -34,8 +34,8 @@ objc_library( srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], deps = [ - "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", - "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", "//mediapipe/framework:packet", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", ], ) From 0e56bd38f3123ced3de8c0c2862b7f7c55549078 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 12:54:58 -0800 Subject: [PATCH 293/346] Fix for CHECK failure due to pointer description sometimes being larger than allocated string space PiperOrigin-RevId: 501355568 --- mediapipe/framework/tool/sink.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/tool/sink.cc b/mediapipe/framework/tool/sink.cc index 4a181b43f..f8abf4925 100644 --- a/mediapipe/framework/tool/sink.cc +++ b/mediapipe/framework/tool/sink.cc @@ -87,7 +87,8 @@ void AddVectorSink(const std::string& stream_name, // node->mutable_options()->MutableExtension( CallbackPacketCalculatorOptions::ext); options->set_type(CallbackPacketCalculatorOptions::VECTOR_PACKET); - char address[17]; + // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. + char address[19]; int written = snprintf(address, sizeof(address), "%p", dumped_data); CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); @@ -112,7 +113,8 @@ void AddPostStreamPacketSink(const std::string& stream_name, node->mutable_options()->MutableExtension( CallbackPacketCalculatorOptions::ext); options->set_type(CallbackPacketCalculatorOptions::POST_STREAM_PACKET); - char address[17]; + // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. 
+ char address[19]; int written = snprintf(address, sizeof(address), "%p", post_stream_packet); CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); From 5612af68cdb6cd157d8a86f13425913521e2de49 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 13:00:37 -0800 Subject: [PATCH 294/346] Propagate compatible_with for drishti_proto_library PiperOrigin-RevId: 501356895 --- mediapipe/framework/tool/mediapipe_graph.bzl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/tool/mediapipe_graph.bzl b/mediapipe/framework/tool/mediapipe_graph.bzl index 45d98b1eb..ef5182a53 100644 --- a/mediapipe/framework/tool/mediapipe_graph.bzl +++ b/mediapipe/framework/tool/mediapipe_graph.bzl @@ -67,7 +67,8 @@ def data_as_c_string( name, srcs, outs = None, - testonly = None): + testonly = None, + compatible_with = None): """Encodes the data from a file as a C string literal. This produces a text file containing the quoted C string literal. It can be @@ -79,6 +80,7 @@ def data_as_c_string( outs: A list containing a single item, the name of the output text file. Defaults to the rule name. testonly: pass 1 if the graph is to be used only for tests. + compatible_with: a list of environments the rule is compatible with. """ if len(srcs) != 1: fail("srcs must be a single-element list") @@ -92,6 +94,7 @@ def data_as_c_string( cmd = "$(location %s) \"$<\" > \"$@\"" % encode_as_c_string, tools = [encode_as_c_string], testonly = testonly, + compatible_with = compatible_with, ) def mediapipe_simple_subgraph( @@ -208,6 +211,7 @@ def mediapipe_options_library( deps = [], visibility = None, testonly = None, + compatible_with = None, **kwargs): """Registers options protobuf metadata for defining options packets. @@ -217,6 +221,7 @@ def mediapipe_options_library( deps: any additional protobuf dependencies. visibility: The list of packages the subgraph should be visible to. testonly: pass 1 if the graph is to be used only for tests. + compatible_with: a list of environments the rule is compatible with. **kwargs: Remaining keyword args, forwarded to cc_library. 
""" @@ -224,16 +229,19 @@ def mediapipe_options_library( name = proto_lib + "_transitive", deps = [proto_lib], testonly = testonly, + compatible_with = compatible_with, ) direct_descriptor_set( name = proto_lib + "_direct", deps = [proto_lib], testonly = testonly, + compatible_with = compatible_with, ) data_as_c_string( name = name + "_inc", srcs = [proto_lib + "_transitive-transitive-descriptor-set.proto.bin"], outs = [proto_lib + "_descriptors.inc"], + compatible_with = compatible_with, ) native.genrule( name = name + "_type_name", @@ -245,6 +253,7 @@ def mediapipe_options_library( tools = ["//mediapipe/framework/tool:message_type_util"], visibility = visibility, testonly = testonly, + compatible_with = compatible_with, ) expand_template( name = name + "_cc", @@ -256,6 +265,7 @@ def mediapipe_options_library( "{{DESCRIPTOR_INC_FILE_PATH}}": native.package_name() + "/" + proto_lib + "_descriptors.inc", }, testonly = testonly, + compatible_with = compatible_with, ) native.cc_library( name = proto_lib.replace("_proto", "_options_registry"), @@ -274,6 +284,7 @@ def mediapipe_options_library( visibility = visibility, testonly = testonly, features = ["-no_undefined"], + compatible_with = compatible_with, **kwargs ) mediapipe_reexport_library( From 36be94f861e57dd21e4e7a64e5a1650e73313753 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 14:21:02 -0800 Subject: [PATCH 295/346] Internal change PiperOrigin-RevId: 501378130 --- .../desktop/autoflip/autoflip_messages.proto | 4 +++ .../calculators/scene_cropping_calculator.cc | 21 +++++++++-- .../scene_cropping_calculator_test.cc | 35 +++++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/mediapipe/examples/desktop/autoflip/autoflip_messages.proto b/mediapipe/examples/desktop/autoflip/autoflip_messages.proto index 8507c9ad7..c89a6aea6 100644 --- a/mediapipe/examples/desktop/autoflip/autoflip_messages.proto +++ b/mediapipe/examples/desktop/autoflip/autoflip_messages.proto @@ -185,6 +185,10 @@ message ExternalRenderFrame { // original dimensions of the input video. The first step to render this // frame is to crop this rect from the input frame. optional Rect crop_from_location = 1; + // Rect that must be cropped out of the input frame. It is defined in the + // ratio of the frame of the input video. The first step to render this frame + // is to crop this rect from the input frame. + optional Rect normalized_crop_from_location = 7; // The placement location where the above rect is placed on the output frame. // This will always have the same aspect ratio as the above rect but scaling // may be required. 
diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc
index 89170dc6a..7e286b743 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc
@@ -201,13 +201,26 @@ absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string,
 void ConstructExternalRenderMessage(
     const cv::Rect& crop_from_location, const cv::Rect& render_to_location,
     const cv::Scalar& padding_color, const uint64 timestamp_us,
-    ExternalRenderFrame* external_render_message) {
+    ExternalRenderFrame* external_render_message, int frame_width,
+    int frame_height) {
   auto crop_from_message =
       external_render_message->mutable_crop_from_location();
   crop_from_message->set_x(crop_from_location.x);
   crop_from_message->set_y(crop_from_location.y);
   crop_from_message->set_width(crop_from_location.width);
   crop_from_message->set_height(crop_from_location.height);
+
+  auto normalized_crop_from_message =
+      external_render_message->mutable_normalized_crop_from_location();
+  normalized_crop_from_message->set_x(crop_from_location.x /
+                                      static_cast<float>(frame_width));
+  normalized_crop_from_message->set_y(crop_from_location.y /
+                                      static_cast<float>(frame_height));
+  normalized_crop_from_message->set_width(crop_from_location.width /
+                                          static_cast<float>(frame_width));
+  normalized_crop_from_message->set_height(crop_from_location.height /
+                                           static_cast<float>(frame_height));
+
   auto render_to_message =
       external_render_message->mutable_render_to_location();
   render_to_message->set_x(render_to_location.x);
@@ -627,7 +640,8 @@ absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene,
       auto external_render_message = absl::make_unique<ExternalRenderFrame>();
       ConstructExternalRenderMessage(
           crop_from_locations[i], render_to_locations[i], padding_colors[i],
-          scene_frame_timestamps_[i], external_render_message.get());
+          scene_frame_timestamps_[i], external_render_message.get(),
+          frame_width_, frame_height_);
       cc->Outputs()
           .Tag(kExternalRenderingPerFrame)
          .Add(external_render_message.release(),
@@ -640,7 +654,8 @@ absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene,
     ExternalRenderFrame render_frame;
     ConstructExternalRenderMessage(crop_from_locations[i],
                                    render_to_locations[i], padding_colors[i],
-                                   scene_frame_timestamps_[i], &render_frame);
+                                   scene_frame_timestamps_[i], &render_frame,
+                                   frame_width_, frame_height_);
     external_render_list_->push_back(render_frame);
   }
 }
diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc
index 88728860a..c3285ea58 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc
@@ -920,6 +920,41 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPathNoVideo) {
     EXPECT_EQ(ext_render_message.render_to_location().height(), 1124);
   }
 }
+
+// Checks external render message with default poly path solver using
+// normalized crops.
+TEST(SceneCroppingCalculatorTest, OutputsCropMessagePolyPathNormalized) {
+  const CalculatorGraphConfig::Node config =
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
+          absl::Substitute(kExternalRenderConfig, kTargetWidth, kTargetHeight));
+  auto runner = absl::make_unique<CalculatorRunner>(config);
+  const int num_frames = kSceneSize;
+  AddScene(0, num_frames, kInputFrameWidth, kInputFrameHeight, kKeyFrameWidth,
+           kKeyFrameHeight, 1, runner->MutableInputs());
+
+  MP_EXPECT_OK(runner->Run());
+  const auto& outputs = runner->Outputs();
+  const auto& ext_render_per_frame =
+      outputs.Tag(kExternalRenderingPerFrameTag).packets;
+  EXPECT_EQ(ext_render_per_frame.size(), num_frames);
+
+  for (int i = 0; i < num_frames - 1; ++i) {
+    const auto& ext_render_message =
+        ext_render_per_frame[i].Get<ExternalRenderFrame>();
+    EXPECT_EQ(ext_render_message.timestamp_us(), i * 20000);
+    EXPECT_EQ(ext_render_message.normalized_crop_from_location().x(),
+              725 / static_cast<float>(kInputFrameWidth));
+    EXPECT_EQ(ext_render_message.normalized_crop_from_location().y(), 0);
+    EXPECT_EQ(ext_render_message.normalized_crop_from_location().width(),
+              461 / static_cast<float>(kInputFrameWidth));
+    EXPECT_EQ(ext_render_message.normalized_crop_from_location().height(),
+              720 / static_cast<float>(kInputFrameHeight));
+    EXPECT_EQ(ext_render_message.render_to_location().x(), 0);
+    EXPECT_EQ(ext_render_message.render_to_location().y(), 0);
+    EXPECT_EQ(ext_render_message.render_to_location().width(), 720);
+    EXPECT_EQ(ext_render_message.render_to_location().height(), 1124);
+  }
+}
 }  // namespace
 }  // namespace autoflip
 }  // namespace mediapipe

From 8830eefa0b96ccc886feb53e1404bae2e0cdf4d1 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Wed, 11 Jan 2023 16:04:57 -0800
Subject: [PATCH 296/346] Internal change.

PiperOrigin-RevId: 501403332
---
 .../formats/tensor/cpu_buffer_converters.cc   | 240 ---------------
 .../tensor/cpu_buffer_converters_test.cc      | 282 ------------------
 2 files changed, 522 deletions(-)
 delete mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters.cc
 delete mode 100644 mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc

diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc b/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc
deleted file mode 100644
index e4e705be5..000000000
--- a/mediapipe/framework/formats/tensor/cpu_buffer_converters.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-#include 
-#include 
-#include 
-
-#include "mediapipe/framework/formats/tensor/backend.h"
-#include "mediapipe/framework/formats/tensor/tensor2.h"
-#include "mediapipe/framework/formats/tensor/views/buffer.h"
-#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h"
-#include "third_party/FP16/include/fp16.h"
-
-namespace mediapipe {
-namespace {
-
-template 
-auto ConverterCheckFunction() {
-  return
-      [](const Tensor2& tensor, uint64_t source_descriptor_type_id,
-         const Tensor2::ViewDescriptor& source_base_descriptor,
-         uint64_t destination_descriptor_type_id,
-         const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool {
-        if (source_descriptor_type_id != TensorCpuView::kId ||
-            destination_descriptor_type_id != TensorCpuView::kId)
-          return false;
-        auto source_descriptor =
-            static_cast(source_base_descriptor);
-        auto destination_descriptor =
-            static_cast(
-                destination_base_descriptor);
-        return source_descriptor.buffer.format ==
-                   TensorTypeToFormat::value &&
-               destination_descriptor.buffer.format ==
-                   TensorTypeToFormat::value;
-      };
-}
-
-template 
-auto ConvertFunction() {
-  return [](const Tensor2& tensor, const Tensor2::View&
source_base_view, - const Tensor2::View& destination_base_view) -> bool { - auto source = source_base_view.DownCast(); - auto destination = destination_base_view.DownCast(); - if (source->descriptor().buffer.format == - destination->descriptor().buffer.format) { - std::memcpy( - destination->data(), source->data(), - TensorBufferSize(destination->descriptor().buffer, tensor.shape())); - } else { - auto source_pointer = source->data(); - auto destination_pointer = destination->data(); - for (int i = 0; i < tensor.shape().NumElements(); i++) { - *destination_pointer++ = - GpuLikeTypeCast(*source_pointer++); - } - } - return true; - }; -} - -#define REGISTER_CONVERTER(SourceType, DestinationType) \ - TENSOR_REGISTER_CONVERTER( \ - {ConverterCheckFunction(), \ - ConvertFunction()}); - -REGISTER_CONVERTER(float, Float16); -REGISTER_CONVERTER(float, int8_t); -REGISTER_CONVERTER(float, uint8_t); -REGISTER_CONVERTER(float, int16_t); -REGISTER_CONVERTER(float, uint16_t); -REGISTER_CONVERTER(float, int32_t); -REGISTER_CONVERTER(float, uint32_t); - -REGISTER_CONVERTER(Float16, float); -REGISTER_CONVERTER(Float16, int8_t); -REGISTER_CONVERTER(Float16, uint8_t); -REGISTER_CONVERTER(Float16, int16_t); -REGISTER_CONVERTER(Float16, uint16_t); -REGISTER_CONVERTER(Float16, int32_t); -REGISTER_CONVERTER(Float16, uint32_t); - -REGISTER_CONVERTER(int8_t, float); -REGISTER_CONVERTER(int8_t, Float16); -REGISTER_CONVERTER(int8_t, uint8_t); -REGISTER_CONVERTER(int8_t, int16_t); -REGISTER_CONVERTER(int8_t, uint16_t); -REGISTER_CONVERTER(int8_t, int32_t); -REGISTER_CONVERTER(int8_t, uint32_t); - -REGISTER_CONVERTER(uint8_t, float); -REGISTER_CONVERTER(uint8_t, Float16); -REGISTER_CONVERTER(uint8_t, int8_t); -REGISTER_CONVERTER(uint8_t, int16_t); -REGISTER_CONVERTER(uint8_t, uint16_t); -REGISTER_CONVERTER(uint8_t, int32_t); -REGISTER_CONVERTER(uint8_t, uint32_t); - -REGISTER_CONVERTER(int16_t, float); -REGISTER_CONVERTER(int16_t, Float16); -REGISTER_CONVERTER(int16_t, int8_t); -REGISTER_CONVERTER(int16_t, uint8_t); -REGISTER_CONVERTER(int16_t, uint16_t); -REGISTER_CONVERTER(int16_t, uint32_t); -REGISTER_CONVERTER(int16_t, uint32_t); - -REGISTER_CONVERTER(uint16_t, float); -REGISTER_CONVERTER(uint16_t, Float16); -REGISTER_CONVERTER(uint16_t, int8_t); -REGISTER_CONVERTER(uint16_t, uint8_t); -REGISTER_CONVERTER(uint16_t, int16_t); -REGISTER_CONVERTER(uint16_t, int32_t); -REGISTER_CONVERTER(uint16_t, uint32_t); - -REGISTER_CONVERTER(int32_t, float); -REGISTER_CONVERTER(int32_t, Float16); -REGISTER_CONVERTER(int32_t, int8_t); -REGISTER_CONVERTER(int32_t, uint8_t); -REGISTER_CONVERTER(int32_t, int16_t); -REGISTER_CONVERTER(int32_t, uint16_t); -REGISTER_CONVERTER(int32_t, uint32_t); - -REGISTER_CONVERTER(uint32_t, float); -REGISTER_CONVERTER(uint32_t, Float16); -REGISTER_CONVERTER(uint32_t, int8_t); -REGISTER_CONVERTER(uint32_t, uint8_t); -REGISTER_CONVERTER(uint32_t, int16_t); -REGISTER_CONVERTER(uint32_t, uint16_t); -REGISTER_CONVERTER(uint32_t, int32_t); - -template -auto DequantizationCheckFunction() { - return - [](const Tensor2& tensor, uint64_t source_descriptor_type_id, - const Tensor2::ViewDescriptor& source_base_descriptor, - uint64_t destination_descriptor_type_id, - const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { - if (source_descriptor_type_id != TensorCpuView::kId || - destination_descriptor_type_id != TensorCpuView::kId) - return false; - auto source_descriptor = - static_cast(source_base_descriptor); - auto destination_descriptor = - static_cast( - destination_base_descriptor); - return 
source_descriptor.buffer.format == - TensorBufferDescriptor::Format::kQuantizedInt8 && - destination_descriptor.buffer.format == - TensorTypeToFormat::value; - }; -} - -template -auto DequantizationConvertFunction() { - return [](const Tensor2& tensor, const Tensor2::View& source_base_view, - const Tensor2::View& destination_base_view) -> bool { - auto source = source_base_view.DownCast(); - auto destination = destination_base_view.DownCast(); - auto source_pointer = source->data(); - auto destination_pointer = destination->data(); - int zero_point = - source->descriptor().buffer.quantization_parameters.zero_point; - float scale = source->descriptor().buffer.quantization_parameters.scale; - for (int i = 0; i < tensor.shape().NumElements(); i++) { - *destination_pointer++ = static_cast( - (*source_pointer++ - zero_point) * scale); - } - return true; - }; -} - -#define REGISTER_DEQUANTIZATION_CONVERTER(DestinationType) \ - TENSOR_REGISTER_CONVERTER( \ - {DequantizationCheckFunction(), \ - DequantizationConvertFunction()}); - -REGISTER_DEQUANTIZATION_CONVERTER(float); -REGISTER_DEQUANTIZATION_CONVERTER(Float16); -REGISTER_DEQUANTIZATION_CONVERTER(int8_t); -REGISTER_DEQUANTIZATION_CONVERTER(uint8_t); -REGISTER_DEQUANTIZATION_CONVERTER(int16_t); -REGISTER_DEQUANTIZATION_CONVERTER(uint16_t); -REGISTER_DEQUANTIZATION_CONVERTER(int32_t); -REGISTER_DEQUANTIZATION_CONVERTER(uint32_t); - -template -auto QuantizationCheckFunction() { - return - [](const Tensor2& tensor, uint64_t source_descriptor_type_id, - const Tensor2::ViewDescriptor& source_base_descriptor, - uint64_t destination_descriptor_type_id, - const Tensor2::ViewDescriptor& destination_base_descriptor) -> bool { - if (source_descriptor_type_id != TensorCpuView::kId || - destination_descriptor_type_id != TensorCpuView::kId) - return false; - auto source_descriptor = - static_cast(source_base_descriptor); - auto destination_descriptor = - static_cast( - destination_base_descriptor); - bool same = source_descriptor.buffer.format == - TensorTypeToFormat::value && - destination_descriptor.buffer.format == - TensorBufferDescriptor::Format::kQuantizedInt8; - return same; - }; -} - -template -auto QuantizationConvertFunction() { - return [](const Tensor2& tensor, const Tensor2::View& source_base_view, - const Tensor2::View& destination_base_view) -> bool { - auto source = source_base_view.DownCast(); - auto destination = destination_base_view.DownCast(); - auto source_pointer = source->data(); - auto destination_pointer = destination->data(); - int zero_point = - destination->descriptor().buffer.quantization_parameters.zero_point; - float scale = - destination->descriptor().buffer.quantization_parameters.scale; - for (int i = 0; i < tensor.shape().NumElements(); i++) { - *destination_pointer++ = - static_cast(*source_pointer++ / scale + zero_point); - } - return true; - }; -} - -#define REGISTER_QUANTIZATION_CONVERTER(SourceType) \ - TENSOR_REGISTER_CONVERTER({QuantizationCheckFunction(), \ - QuantizationConvertFunction()}); - -REGISTER_QUANTIZATION_CONVERTER(float); -REGISTER_QUANTIZATION_CONVERTER(Float16); -REGISTER_QUANTIZATION_CONVERTER(int8_t); -REGISTER_QUANTIZATION_CONVERTER(uint8_t); -REGISTER_QUANTIZATION_CONVERTER(int16_t); -REGISTER_QUANTIZATION_CONVERTER(uint16_t); -REGISTER_QUANTIZATION_CONVERTER(int32_t); -REGISTER_QUANTIZATION_CONVERTER(uint32_t); - -} // namespace -} // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc 
b/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc deleted file mode 100644 index 3619ad531..000000000 --- a/mediapipe/framework/formats/tensor/cpu_buffer_converters_test.cc +++ /dev/null @@ -1,282 +0,0 @@ -#include - -#include "mediapipe/framework/formats/tensor/tensor2.h" -#include "mediapipe/framework/formats/tensor/views/buffer.h" -#include "mediapipe/framework/formats/tensor/views/cpu_buffer.h" -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" - -MATCHER_P(NearWithPrecision, precision, "") { - return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; -} -MATCHER_P(IntegerEqual, precision, "") { - return std::get<0>(arg) == std::get<1>(arg); -} - -namespace mediapipe { - -TEST(TensorCpuViewTest, TestWrite32ThenRead16) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 1234.0f; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat16}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 1234.0f); - } -} - -TEST(TensorCpuViewTest, TestWrite16ThenRead32) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat16}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 1234.0f; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 1234.0f); - } -} - -TEST(TensorCpuViewTest, TestWriteFloat32ThenReadInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 0.121569f; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ( - *view->data(), - static_cast(0.121569f * std::numeric_limits::max())); - } -} - -TEST(TensorCpuViewTest, TestWriteInt8ThenReadFloat32) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = - static_cast(0.123f * std::numeric_limits::max()); - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_NEAR(*view->data(), 0.123f, - 1.0f / std::numeric_limits::max()); - } -} - -TEST(TensorCpuViewTest, TestWriteUInt8ThenReadUInt16) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 123; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - 
TensorBufferDescriptor::Format::kUInt16}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), uint16_t{123} << 8); - } -} - -TEST(TensorCpuViewTest, TestWriteUInt16ThenReadUInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kUInt16}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = uint16_t{123} << 8; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 123); - } -} - -TEST(TensorCpuViewTest, TestWriteNegativeInt8ThenReadUInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = -123; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 0); - } -} - -TEST(TensorCpuViewTest, TestWritePositiveInt8ThenReadUInt8) { - Tensor2 tensor{Tensor2::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - ASSERT_NE(view->data(), nullptr); - *view->data() = 123; - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kUInt8}})); - ASSERT_NE(view->data(), nullptr); - EXPECT_EQ(*view->data(), 123 * 2); - } -} - -TEST(TensorCpuViewTest, TestDequantization) { - constexpr int num_elements = 20; - // Gives quantization values in range [-100, 90]. - constexpr int zero_point = -100; - constexpr float scale = 2.0f; - Tensor2 tensor{Tensor2::Shape({num_elements})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = { - .format = TensorBufferDescriptor::Format::kQuantizedInt8, - .quantization_parameters = {.scale = scale, - .zero_point = zero_point}}})); - ASSERT_NE(view->data(), nullptr); - auto data = view->data(); - for (int i = 0; i < num_elements; ++i) { - // Add some bias (+1) to make round-up take place. - data[i] = (i * 20 + 1) / scale + zero_point; - } - } - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - std::vector reference(num_elements); - for (int i = 0; i < num_elements; ++i) { - reference[i] = i * 20.0f + 1.0f; - } - EXPECT_THAT(absl::Span(view->data(), num_elements), - testing::Pointwise(NearWithPrecision(1.001), reference)); - } -} - -TEST(TensorCpuViewTest, TestQuantization) { - constexpr int num_elements = 20; - // Gives quantization values in range [-100, 90]. - constexpr int zero_point = -100; - constexpr float scale = 2.0f; - Tensor2 tensor{Tensor2::Shape({num_elements})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - ASSERT_NE(view->data(), nullptr); - auto data = view->data(); - for (int i = 0; i < num_elements; ++i) { - // Add some bias (+1) to make round-up take place. 
- data[i] = i * 20 + 1; - } - } - { - TensorCpuViewDescriptor d{ - .buffer = {.format = TensorBufferDescriptor::Format::kQuantizedInt8, - .quantization_parameters = {.scale = scale, - .zero_point = zero_point}}}; - MP_ASSERT_OK_AND_ASSIGN( - auto view, tensor.GetView(d)); - ASSERT_NE(view->data(), nullptr); - std::vector reference(num_elements); - for (int i = 0; i < num_elements; ++i) { - reference[i] = (i * 20 + 1) / scale + zero_point; - } - EXPECT_THAT(absl::Span(view->data(), num_elements), - testing::Pointwise(IntegerEqual(0), reference)); - } -} - -} // namespace mediapipe From 9cbb76939dd069eacecae103c5c27b6e07c7e9c7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 20:33:26 -0800 Subject: [PATCH 297/346] Adds smaller MobileBERT model. PiperOrigin-RevId: 501451414 --- .../model_maker/models/text_classifier/BUILD | 45 ++++++++++ .../python/text/text_classifier/BUILD | 11 +++ .../python/text/text_classifier/model_spec.py | 13 +-- .../text/text_classifier/model_spec_test.py | 7 +- .../text/text_classifier/testdata/BUILD | 5 +- .../testdata/bert_metadata.json | 84 +++++++++++++++++++ .../text/text_classifier/text_classifier.py | 13 ++- .../text_classifier/text_classifier_test.py | 25 +++++- mediapipe/model_maker/setup.py | 12 ++- third_party/external_files.bzl | 30 +++++++ 10 files changed, 228 insertions(+), 17 deletions(-) create mode 100644 mediapipe/model_maker/models/text_classifier/BUILD create mode 100644 mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json diff --git a/mediapipe/model_maker/models/text_classifier/BUILD b/mediapipe/model_maker/models/text_classifier/BUILD new file mode 100644 index 000000000..4c54bbccc --- /dev/null +++ b/mediapipe/model_maker/models/text_classifier/BUILD @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"], +) + +mediapipe_files( + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) + +filegroup( + name = "mobilebert_tiny", + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 7bb41351e..43f2b6c75 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -53,6 +53,7 @@ py_library( deps = [ ":model_options", "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/utils:file_util", "//mediapipe/model_maker/python/text/core:bert_model_spec", ], ) @@ -88,6 +89,9 @@ py_library( py_test( name = "preprocessor_test", srcs = ["preprocessor_test.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], tags = ["requires-net:external"], deps = [ ":dataset", @@ -109,6 +113,9 @@ py_library( py_library( name = "text_classifier", srcs = ["text_classifier.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":dataset", ":model_options", @@ -130,6 +137,7 @@ py_test( size = "large", srcs = ["text_classifier_test.py"], data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", "//mediapipe/model_maker/python/text/text_classifier/testdata", ], tags = ["requires-net:external"], @@ -151,6 +159,9 @@ py_library( py_binary( name = "text_classifier_demo", srcs = ["text_classifier_demo.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":text_classifier_demo_lib", ], diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 9df7e1039..a6bdd9522 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -18,12 +18,15 @@ import enum import functools from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.text.core import bert_model_spec from mediapipe.model_maker.python.text.text_classifier import model_options as mo # BERT-based text classifier spec inherited from BertModelSpec BertClassifierSpec = bert_model_spec.BertModelSpec +MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/' + @dataclasses.dataclass class AverageWordEmbeddingClassifierSpec: @@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial( BertClassifierSpec, hparams=hp.BaseHParams( - epochs=3, - batch_size=48, - learning_rate=3e-5, - distribution_strategy='off'), + epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' + ), name='MobileBert', - 
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1', + uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH), tflite_input_name={ 'ids': 'serving_default_input_1:0', 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0' + 'segment_ids': 'serving_default_input_2:0', }, ) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index dd7f880f3..3ea019b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertEqual( - model_spec_obj.uri, 'https://tfhub.dev/tensorflow/' - 'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1') + self.assertIn( + 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny', + model_spec_obj.uri, + ) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, { diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD index 663c72082..a581462cf 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD @@ -19,5 +19,8 @@ package( filegroup( name = "testdata", - srcs = ["average_word_embedding_metadata.json"], + srcs = [ + "average_word_embedding_metadata.json", + "bert_metadata.json", + ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json new file mode 100644 index 000000000..24214a80d --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json @@ -0,0 +1,84 @@ +{ + "name": "TextClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words 
to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "min_parser_version": "1.1.0" +} diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 1a338e345..f6abc8bf0 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -269,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier): """Creates an Average Word Embedding model.""" self._model = tf.keras.Sequential([ tf.keras.layers.InputLayer( - input_shape=[self._model_options.seq_len], dtype=tf.int32), + input_shape=[self._model_options.seq_len], + dtype=tf.int32, + name="input_ids", + ), tf.keras.layers.Embedding( len(self._text_preprocessor.get_vocab()), self._model_options.wordvec_dim, - input_length=self._model_options.seq_len), + input_length=self._model_options.seq_len, + ), tf.keras.layers.GlobalAveragePooling1D(), tf.keras.layers.Dense( - self._model_options.wordvec_dim, activation=tf.nn.relu), + self._model_options.wordvec_dim, activation=tf.nn.relu + ), tf.keras.layers.Dropout(self._model_options.dropout_rate), - tf.keras.layers.Dense(self._num_classes, activation="softmax") + tf.keras.layers.Dense(self._num_classes, activation="softmax"), ]) def _save_vocab(self, vocab_filepath: str): diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index eb4443b44..1ae2bc553 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( test_utils.get_test_data_path('average_word_embedding_metadata.json')) + _BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path( + 'bert_metadata.json' + ) def _get_data(self): labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) @@ -94,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase): _, accuracy = bert_classifier.evaluate(validation_data) self.assertGreaterEqual(accuracy, 0.0) - # TODO: Add a unit test that does not run OOM. 
+ + # Test export_model + bert_classifier.export_model() + output_metadata_file = os.path.join( + options.hparams.export_dir, 'metadata.json' + ) + output_tflite_file = os.path.join( + options.hparams.export_dir, 'model.tflite' + ) + + self.assertTrue(os.path.exists(output_tflite_file)) + self.assertGreater(os.path.getsize(output_tflite_file), 0) + + self.assertTrue(os.path.exists(output_metadata_file)) + self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False + ) + ) def test_label_mismatch(self): options = ( diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index 7114e2080..1dac6301a 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -81,7 +81,10 @@ def _setup_build_dir(): file.write(filedata) # Use bazel to download GCS model files - model_build_files = ['models/gesture_recognizer/BUILD'] + model_build_files = [ + 'models/gesture_recognizer/BUILD', + 'models/text_classifier/BUILD', + ] for model_build_file in model_build_files: build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) os.makedirs(os.path.dirname(build_target_file), exist_ok=True) @@ -95,7 +98,12 @@ def _setup_build_dir(): 'models/gesture_recognizer/gesture_embedder/saved_model.pb', 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', 'models/gesture_recognizer/gesture_embedder/variables/variables.index', - ] + 'models/text_classifier/mobilebert_tiny/keras_metadata.pb', + 'models/text_classifier/mobilebert_tiny/saved_model.pb', + 'models/text_classifier/mobilebert_tiny/assets/vocab.txt', + 'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001', + 'models/text_classifier/mobilebert_tiny/variables/variables.index', + ] for elem in external_files: external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) sys.stderr.write('downloading file: %s\n' % external_file) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 790486676..5adfbdfc6 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -1006,6 +1006,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], ) + http_file( + name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb", + sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb", + sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"], + ) + http_file( name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001", sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97", @@ -1053,3 +1065,21 @@ def external_files(): sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt", + sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", + 
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001", + sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index", + sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"], + ) From 5c74ed2ae58eeb6b6f9b18aa47edf52e08a0eccb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 12 Jan 2023 08:27:57 -0800 Subject: [PATCH 298/346] EmbeddingAggregationCalculator should fill in the `timestamp_ms` field of the embedding results in the stream mode. Per user feedback, the consistency between the packet timestamp and the timestamp field of the embedding result helps reducing the confusion. PiperOrigin-RevId: 501572379 --- .../calculators/embedding_aggregation_calculator.cc | 4 +++- .../calculators/embedding_aggregation_calculator_test.cc | 8 +++++--- .../processors/embedding_postprocessing_graph_test.cc | 7 ++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc index bae926b76..6e06c4e32 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc @@ -120,7 +120,9 @@ absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) { } kTimestampedEmbeddingsOut(cc).Send(std::move(results)); } else { - kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc)); + auto result = kEmbeddingsIn(cc).Get(); + result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000); + kEmbeddingsOut(cc).Send(result); } RET_CHECK(cached_embeddings_.empty()); return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc index ebb4d8880..f2b2fa1d5 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc @@ -120,7 +120,7 @@ class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { CalculatorGraph calculator_graph_; }; -TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutAggregation) { EmbeddingResult embedding = ParseTextProtoOrDie( R"pb(embeddings { head_index: 0 })pb"); @@ -129,10 +129,12 @@ TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { MP_ASSERT_OK(Send(embedding)); MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); - EXPECT_THAT(result, EqualsProto(embedding)); + EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( + R"pb(timestamp_ms: 0 + embeddings { head_index: 0 })pb"))); } -TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) { +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, 
BuildGraph(/*connect_timestamps=*/true)); MP_ASSERT_OK(Send(ParseTextProtoOrDie(R"pb(embeddings { head_index: 0 diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 163e46ee8..809268a63 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -246,7 +246,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { absl::make_unique>(); }; -TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { +TEST_F(PostprocessingTest, SucceedsWithoutAggregation) { // Build graph. proto::EmbedderOptions options; MP_ASSERT_OK_AND_ASSIGN(auto poller, @@ -261,7 +261,8 @@ TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult(poller)); // Validate results. - EXPECT_FALSE(results.has_timestamp_ms()); + EXPECT_TRUE(results.has_timestamp_ms()); + EXPECT_EQ(results.timestamp_ms(), 0); EXPECT_EQ(results.embeddings_size(), 1); EXPECT_EQ(results.embeddings(0).head_index(), 0); EXPECT_EQ(results.embeddings(0).head_name(), "feature"); @@ -273,7 +274,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { } } -TEST_F(PostprocessingTest, SucceedsWithTimestamps) { +TEST_F(PostprocessingTest, SucceedsWithAggregation) { // Build graph. proto::EmbedderOptions options; MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options, From 74b60780c7a5e9fe07c10513201b499a99fd137e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 12 Jan 2023 09:58:34 -0800 Subject: [PATCH 299/346] Internal change PiperOrigin-RevId: 501594400 --- mediapipe/framework/deps/registration.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index 1a33b2b24..9d80aafea 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -253,7 +253,7 @@ class FunctionRegistry { if (names[0].empty()) { names.erase(names.begin()); } else { - CHECK_EQ(1, names.size()) + CHECK_EQ(1u, names.size()) << "A registered class name must be either fully qualified " << "with a leading :: or unqualified, got: " << name << "."; } From 1683d572ed778c444d8c8e1b7f9f9a240a65667e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 12 Jan 2023 10:20:09 -0800 Subject: [PATCH 300/346] Internal change PiperOrigin-RevId: 501600938 --- mediapipe/gpu/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index cc5e50dfc..9074daf61 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -1021,8 +1021,8 @@ objc_library( visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", + ":copy_calculator_cc_proto", ":simple_shaders_mtl", - "//mediapipe/gpu:copy_calculator_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", "//third_party/apple_frameworks:CoreVideo", "//third_party/apple_frameworks:Metal", From 8156da341833e9d6c8042a0cef1494770458c8a0 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 12 Jan 2023 13:52:27 -0800 Subject: [PATCH 301/346] ClassificationAggregationCalculator should fill in the `timestamp_ms` field of the classification results in the stream mode. Per user feedback, the consistency between the packet timestamp and the timestamp field of the classification result helps reducing the confusion. 
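In practice this lets a streaming client line results up against the input clock without carrying the packet timestamp separately. A small stand-alone sketch of the expectation the audio classifier test below encodes (simplified types; the real result type is AudioClassifierResult):

    #include <cstddef>
    #include <cstdint>
    #include <optional>
    #include <vector>

    // Stand-in for a streaming classification result's timestamp field.
    struct Result {
      std::optional<int64_t> timestamp_ms;
    };

    // One result per 0.975 s chunk of audio, so the timestamps are now
    // 0, 975, 1950 and 2925 ms instead of unset.
    bool TimestampsAreSet(const std::vector<Result>& results) {
      const std::vector<int64_t> expected = {0, 975, 1950, 2925};
      for (std::size_t i = 0; i < expected.size() && i < results.size(); ++i) {
        if (results[i].timestamp_ms != expected[i]) return false;
      }
      return true;
    }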
PiperOrigin-RevId: 501657922 --- .../audio_classifier/audio_classifier_test.cc | 3 +- .../classification_aggregation_calculator.cc | 1 + ...ssification_aggregation_calculator_test.cc | 7 ++- ...lassification_postprocessing_graph_test.cc | 4 ++ .../test/vision/image_classifier_test.py | 57 ++++++++++++------- 5 files changed, 46 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 596b910f8..2d5b221a9 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -143,8 +143,9 @@ void CheckStreamingModeResults(std::vector outputs) { EXPECT_EQ(outputs.size(), 5); // Ignore last result, which operates on a too small chunk to return relevant // results. + std::vector timestamps_ms = {0, 975, 1950, 2925}; for (int i = 0; i < outputs.size() - 1; i++) { - EXPECT_FALSE(outputs[i].timestamp_ms.has_value()); + EXPECT_EQ(outputs[i].timestamp_ms.value(), timestamps_ms[i]); EXPECT_EQ(outputs[i].classifications.size(), 1); EXPECT_EQ(outputs[i].classifications[0].head_index, 0); EXPECT_EQ(outputs[i].classifications[0].head_name, "scores"); diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index ad2c668c3..145076cd3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -188,6 +188,7 @@ ClassificationAggregationCalculator::ConvertToClassificationResult( *classifications->mutable_classification_list() = std::move(classification_lists[i]); } + result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000); cached_classifications_.erase(cc->InputTimestamp().Value()); return result; } diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc index 1bc8cafd6..811d70544 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc @@ -150,14 +150,15 @@ class ClassificationAggregationCalculatorTest CalculatorGraph calculator_graph_; }; -TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph()); MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( - R"pb(classifications { + R"pb(timestamp_ms: 0, + classifications { head_index: 0 head_name: "foo" classification_list { classification { index: 0 } } @@ -169,7 +170,7 @@ TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { })pb"))); } -TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) { +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); MP_ASSERT_OK(Send( diff --git 
a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index 8eb6f3c3b..a11bad71a 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -534,6 +534,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Validate results. EXPECT_THAT(results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 classification_list { @@ -567,6 +568,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "probability" @@ -603,6 +605,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "probability" @@ -646,6 +649,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "yamnet_classification" diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index cbeaf36bd..b47efb32b 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -61,7 +61,7 @@ def _generate_empty_results() -> ImageClassifierResult: timestamp_ms=0) -def _generate_burger_results() -> ImageClassifierResult: +def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult: return ImageClassifierResult( classifications=[ _Classifications( @@ -70,30 +70,36 @@ def _generate_burger_results() -> ImageClassifierResult: index=934, score=0.793959, display_name='', - category_name='cheeseburger'), + category_name='cheeseburger', + ), _Category( index=932, score=0.0273929, display_name='', - category_name='bagel'), + category_name='bagel', + ), _Category( index=925, score=0.0193408, display_name='', - category_name='guacamole'), + category_name='guacamole', + ), _Category( index=963, score=0.00632786, display_name='', - category_name='meat loaf') + category_name='meat loaf', + ), ], head_index=0, - head_name='probability') + head_name='probability', + ) ], - timestamp_ms=0) + timestamp_ms=timestamp_ms, + ) -def _generate_soccer_ball_results() -> ImageClassifierResult: +def _generate_soccer_ball_results(timestamp_ms=0) -> ImageClassifierResult: return ImageClassifierResult( classifications=[ _Classifications( @@ -102,12 +108,15 @@ def _generate_soccer_ball_results() -> ImageClassifierResult: index=806, score=0.996527, display_name='', - category_name='soccer ball') + category_name='soccer ball', + ) ], head_index=0, - head_name='probability') + head_name='probability', + ) ], - timestamp_ms=0) + timestamp_ms=timestamp_ms, + ) class ModelFileType(enum.Enum): @@ -379,8 +388,11 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - test_utils.assert_proto_equals(self, classification_result.to_pb2(), - _generate_burger_results().to_pb2()) + test_utils.assert_proto_equals( + self, + classification_result.to_pb2(), + 
_generate_burger_results(timestamp).to_pb2(), + ) def test_classify_for_video_succeeds_with_region_of_interest(self): options = _ImageClassifierOptions( @@ -398,8 +410,11 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( test_image, timestamp, image_processing_options) - test_utils.assert_proto_equals(self, classification_result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, + classification_result.to_pb2(), + _generate_soccer_ball_results(timestamp).to_pb2(), + ) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -455,8 +470,7 @@ class ImageClassifierTest(parameterized.TestCase): score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: - for timestamp in range(0, 300, 30): - classifier.classify_async(self.test_image, timestamp) + classifier.classify_async(self.test_image, 0) def test_classify_async_succeeds_with_region_of_interest(self): # Load the test image. @@ -470,8 +484,9 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: ImageClassifierResult, output_image: _Image, timestamp_ms: int): - test_utils.assert_proto_equals(self, result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2() + ) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) @@ -483,9 +498,7 @@ class ImageClassifierTest(parameterized.TestCase): max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: - for timestamp in range(0, 300, 30): - classifier.classify_async(test_image, timestamp, - image_processing_options) + classifier.classify_async(test_image, 100, image_processing_options) if __name__ == '__main__': From 5642980ab01466b1fce7c1abad701ba2f0f13a76 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:03 +0530 Subject: [PATCH 302/346] Updated iOS error implementation to mimic java --- .../tasks/ios/common/sources/MPPCommon.h | 159 +++--------------- .../common/utils/sources/MPPCommonUtils.mm | 125 +++++++++----- 2 files changed, 104 insertions(+), 180 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 09a61e20d..f8047fc35 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -25,153 +25,44 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { // Generic error codes. - // Unspecified error. - MPPTasksErrorCodeError = 1, - // Invalid argument specified. - MPPTasksErrorCodeInvalidArgumentError = 2, - // Invalid FlatBuffer file or buffer specified. - MPPTasksErrorCodeInvalidFlatBufferError = 3, - // Model contains a builtin op that isn't supported by the OpResolver or - // delegates. - MPPTasksErrorCodeUnsupportedBuiltinOp = 4, - // Model contains a custom op that isn't supported by the OpResolver or - // delegates. - MPPTasksErrorCodeUnsupportedCustomOp = 5, + /** Indicates the operation was cancelled, typically by the caller. */ + MPPTasksErrorCodeCancelledError = 1, + /** Indicates an unknown error occurred. 
*/ + MPPTasksErrorCodeUnknownError = 2, + /** Indicates the caller specified an invalid argument, such as a malformed filename. */ + MPPTasksErrorCodeInvalidArgumentError = 3, + /** Indicates a deadline expired before the operation could complete. */ + MPPTasksErrorCodeDeadlineExceededError = 4, + /** Indicates some requested entity (such as a file or directory) was not found. */ + MPPTasksErrorCodeNotFoundError = 5, + /** Indicates that the entity a caller attempted to create (such as a file or directory) is already present. */ + MPPTasksErrorCodeAlreadyExistsError = 6, + /** Indicates that the caller does not have permission to execute the specified operation. */ + MPPTasksErrorCodePermissionDeniedError = 7, - // File I/O error codes. + MPPTasksErrorCodeResourceExhaustedError = 8, - // No such file. - MPPTasksErrorCodeFileNotFoundError = 100, - // Permission issue. - MPPTasksErrorCodeFilePermissionDeniedError, - // I/O error when reading file. - MPPTasksErrorCodeFileReadError, - // I/O error when mmap-ing file. - MPPTasksErrorCodeFileMmapError, - // ZIP I/O error when unpacking the zip file. - MPPTasksErrorCodeFileZipError, + MPPTasksErrorCodeFailedPreconditionError = 9, - // TensorFlow Lite metadata error codes. + MPPTasksErrorCodeAbortedError = 10, - // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. - MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, - // No such associated file within metadata, or file has not been packed. - MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, - // ZIP I/O error when unpacking an associated file. - MPPTasksErrorCodeMetadataAssociatedFileZipError, - // Inconsistency error between the metadata and actual TF Lite model. - // E.g.: number of labels and output tensor values differ. - MPPTasksErrorCodeMetadataInconsistencyError, - // Invalid process units specified. - // E.g.: multiple ProcessUnits with the same type for a given tensor. - MPPTasksErrorCodeMetadataInvalidProcessUnitsError, - // Inconsistency error with the number of labels. - // E.g.: label files for different locales have a different number of labels. - MPPTasksErrorCodeMetadataNumLabelsMismatchError, - // Score calibration parameters parsing error. - // E.g.: too many parameters provided in the corresponding associated file. - MPPTasksErrorCodeMetadataMalformedScoreCalibrationError, - // Unexpected number of subgraphs for the current task. - // E.g.: image classification expects a single subgraph. - MPPTasksErrorCodeMetadataInvalidNumSubgraphsError, - // A given tensor requires NormalizationOptions but none were found. - // E.g.: float input tensor requires normalization to preprocess input images. - MPPTasksErrorCodeMetadataMissingNormalizationOptionsError, - // Invalid ContentProperties specified. - // E.g. expected ImageProperties, got BoundingBoxProperties. - MPPTasksErrorCodeMetadataInvalidContentPropertiesError, - // Metadata is mandatory but was not found. - // E.g. current task requires TFLite Model Metadata but none was found. - MPPTasksErrorCodeMetadataNotFoundError, - // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but - // none was found or it was empty. - // E.g. current task requires labels but none were found. - MPPTasksErrorCodeMetadataMissingLabelsError, - // The ProcessingUnit for tokenizer is not correctly configured. - // E.g BertTokenizer doesn't have a valid vocab file associated. - MPPTasksErrorCodeMetadataInvalidTokenizerError, + MPPTasksErrorCodeOutOfRangeError = 11, - // Input tensor(s) error codes. 
+ MPPTasksErrorCodeUnimplementedError = 12, - // Unexpected number of input tensors for the current task. - // E.g. current task expects a single input tensor. - MPPTasksErrorCodeInvalidNumInputTensorsError = 300, - // Unexpected input tensor dimensions for the current task. - // E.g.: only 4D input tensors supported. - MPPTasksErrorCodeInvalidInputTensorDimensionsError, - // Unexpected input tensor type for the current task. - // E.g.: current task expects a uint8 pixel image as input. - MPPTasksErrorCodeInvalidInputTensorTypeError, - // Unexpected input tensor bytes size. - // E.g.: size in bytes does not correspond to the expected number of pixels. - MPPTasksErrorCodeInvalidInputTensorSizeError, - // No correct input tensor found for the model. - // E.g.: input tensor name is not part of the text model's input tensors. - MPPTasksErrorCodeInputTensorNotFoundError, + MPPTasksErrorCodeInternalError = 13, - // Output tensor(s) error codes. + MPPTasksErrorCodeUnavailableError = 14, - // Unexpected output tensor dimensions for the current task. - // E.g.: only a batch size of 1 is supported. - MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400, - // Unexpected input tensor type for the current task. - // E.g.: multi-head model with different output tensor types. - MPPTasksErrorCodeInvalidOutputTensorTypeError, - // No correct output tensor found for the model. - // E.g.: output tensor name is not part of the text model's output tensors. - MPPTasksErrorCodeOutputTensorNotFoundError, - // Unexpected number of output tensors for the current task. - // E.g.: current task expects a single output tensor. - MPPTasksErrorCodeInvalidNumOutputTensorsError, + MPPTasksErrorCodeDataLossError = 15, - // Image processing error codes. - - // Unspecified image processing failures. - MPPTasksErrorCodeImageProcessingError = 500, - // Unexpected input or output buffer metadata. - // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. - MPPTasksErrorCodeImageProcessingInvalidArgumentError, - // Image processing operation failures. - // E.g. libyuv rotation failed for an unknown reason. - MPPTasksErrorCodeImageProcessingBackendError, - - // Task runner error codes. - MPPTasksErrorCodeRunnerError = 600, - // Task runner is not initialized. - MPPTasksErrorCodeRunnerInitializationError, - // Task runner is not started successfully. - MPPTasksErrorCodeRunnerFailsToStartError, - // Task runner is not started. - MPPTasksErrorCodeRunnerNotStartedError, - // Task runner API is called in the wrong processing mode. - MPPTasksErrorCodeRunnerApiCalledInWrongModeError, - // Task runner receives/produces invalid MediaPipe packet timestamp. - MPPTasksErrorCodeRunnerInvalidTimestampError, - // Task runner receives unexpected MediaPipe graph input packet. - // E.g. The packet type doesn't match the graph input stream's data type. - MPPTasksErrorCodeRunnerUnexpectedInputError, - // Task runner produces unexpected MediaPipe graph output packet. - // E.g. The number of output packets is not equal to the number of graph - // output streams. - MPPTasksErrorCodeRunnerUnexpectedOutputError, - // Task runner is not closed successfully. - MPPTasksErrorCodeRunnerFailsToCloseError, - // Task runner's model resources cache service is unavailable or the - // targeting model resources bundle is not found. - MPPTasksErrorCodeRunnerModelResourcesCacheServiceError, - - // Task graph error codes. - MPPTasksErrorCodeGraphError = 700, - // Task graph is not implemented. 
- MPPTasksErrorCodeTaskGraphNotImplementedError, - // Task graph config is invalid. - MPPTasksErrorCodeInvalidTaskGraphConfigError, + MPPTasksErrorCodeUnauthenticatedError = 16, // The first error code in MPPTasksErrorCode (for internal use only). - MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, + MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError, // The last error code in MPPTasksErrorCode (for internal use only). - MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, + MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError, } NS_SWIFT_NAME(TasksErrorCode); diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 1a37f8465..9932dd13c 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -25,6 +25,10 @@ /** Error domain of MediaPipe task library errors. */ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; +namespace { + using absl::StatusCode; +} + @implementation MPPCommonUtils + (void)createCustomError:(NSError **)error @@ -67,52 +71,6 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of - // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum - // stored in the payload is extracted here to later map to the appropriate error code to be - // returned. In cases where the enum is not stored in (payload is NULL or the payload string - // cannot be converted to an integer), we set the error code value to be 1 - // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify - // any errors not falling into other categories.) Since payload is of type absl::Cord that can be - // type cast into an absl::optional, we use the std::stoi function to convert it into - // an integer code if possible. - NSUInteger genericErrorCode = MPPTasksErrorCodeError; - NSUInteger errorCode; - try { - // Try converting payload to integer if payload is not empty. Otherwise convert a string - // signifying generic error code MPPTasksErrorCodeError to integer. - errorCode = - (NSUInteger)std::stoi(static_cast>( - status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload)) - .value_or(std::to_string(genericErrorCode))); - } catch (std::invalid_argument &e) { - // If non empty payload string cannot be converted to an integer. Set error code to 1(kError). - errorCode = MPPTasksErrorCodeError; - } - - // If errorCode is outside the range of enum values possible or is - // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign - // appropriate MPPTasksErrorCode in default cases. Note: - // The mapping to absl::Status::code() is done to generate a more specific error code than - // MPPTasksErrorCodeError in cases when the payload can't be mapped to - // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn - // returned without modification by MediaPipe cc library methods. 
- if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { - switch (status.code()) { - case absl::StatusCode::kInternal: - errorCode = MPPTasksErrorCodeError; - break; - case absl::StatusCode::kInvalidArgument: - errorCode = MPPTasksErrorCodeInvalidArgumentError; - break; - case absl::StatusCode::kNotFound: - errorCode = MPPTasksErrorCodeError; - break; - default: - errorCode = MPPTasksErrorCodeError; - break; - } - } // Creates the NSEror with the appropriate error // MPPTasksErrorCode and message. MPPTasksErrorCode has a one to one @@ -129,6 +87,81 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; + + // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of + // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum + // stored in the payload is extracted here to later map to the appropriate error code to be + // returned. In cases where the enum is not stored in (payload is NULL or the payload string + // cannot be converted to an integer), we set the error code value to be 1 + // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify + // any errors not falling into other categories.) Since payload is of type absl::Cord that can be + // type cast into an absl::optional, we use the std::stoi function to convert it into + // an integer code if possible. + MPPTasksErrorCode genericErrorCode = MPPTasksErrorCodeUnknownError; + + MPPTasksErrorCode errorCode = genericErrorCode; + + // If errorCode is outside the range of enum values possible or is + // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign + // appropriate MPPTasksErrorCode in default cases. Note: + // The mapping to absl::Status::code() is done to generate a more specific error code than + // MPPTasksErrorCodeError in cases when the payload can't be mapped to + // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn + // returned without modification by MediaPipe cc library methods. 
+ switch (status.code()) { + case StatusCode::kCancelled: + errorCode = MPPTasksErrorCodeCancelledError; + break; + case StatusCode::kUnknown: + errorCode = MPPTasksErrorCodeUnknownError; + break; + case StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case StatusCode::kDeadlineExceeded: + errorCode = MPPTasksErrorCodeDeadlineExceededError; + break; + case StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeNotFoundError; + break; + case StatusCode::kAlreadyExists: + errorCode = MPPTasksErrorCodeAlreadyExistsError; + break; + case StatusCode::kPermissionDenied: + errorCode = MPPTasksErrorCodePermissionDeniedError; + break; + case StatusCode::kResourceExhausted: + errorCode = MPPTasksErrorCodeResourceExhaustedError; + break; + case StatusCode::kFailedPrecondition: + errorCode = MPPTasksErrorCodeFailedPreconditionError; + break; + case StatusCode::kAborted: + errorCode = MPPTasksErrorCodeAbortedError; + break; + case StatusCode::kOutOfRange: + errorCode = MPPTasksErrorCodeOutOfRangeError; + break; + case StatusCode::kUnimplemented: + errorCode = MPPTasksErrorCodeUnimplementedError; + break; + case StatusCode::kInternal: + errorCode = MPPTasksErrorCodeInternalError; + break; + case StatusCode::kUnavailable: + errorCode = MPPTasksErrorCodeUnavailableError; + break; + case StatusCode::kDataLoss: + errorCode = MPPTasksErrorCodeDataLossError; + break; + case StatusCode::kUnauthenticated: + errorCode = MPPTasksErrorCodeUnauthenticatedError; + break; + default: + errorCode = genericErrorCode; + break; + } + [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; return NO; } From fa30100059330e9498469e4ca5065686a2079ee7 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:17 +0530 Subject: [PATCH 303/346] Changed swift name of MPPCategory --- mediapipe/tasks/ios/components/containers/sources/MPPCategory.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index d05cfe13b..f360d46da 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN * index of the label in the corresponding label file. Typically it's used as the result of * classification tasks. */ -NS_SWIFT_NAME(ClassificationCategory) +NS_SWIFT_NAME(ResultCategory) @interface MPPCategory : NSObject /** From 0a707256e3b6a993447bf9b6206688e1e6bb58f0 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:04:43 +0530 Subject: [PATCH 304/346] Updates to method signatures of iOS text classifier --- .../ios/text/text_classifier/sources/MPPTextClassifier.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 60aa94614..e33615dab 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -65,7 +65,7 @@ NS_SWIFT_NAME(TextClassifier) * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an * error in initializing the text classifier. 
*/
-- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
+- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;

 /**
  * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`.
@@ -78,7 +78,7 @@ NS_SWIFT_NAME(TextClassifier)
  * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an
  * error in initializing the text classifier.
  */
-- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options
+- (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options
                  error:(NSError **)error NS_DESIGNATED_INITIALIZER;

 /**
@@ -90,7 +90,8 @@ NS_SWIFT_NAME(TextClassifier)
  *
  * @return A `MPPTextClassifierResult` object that contains a list of text classifications.
  */
-- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error;
+- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error NS_SWIFT_NAME(classify(text:));
+

 - (instancetype)init NS_UNAVAILABLE;

From c4c07acc1e5b2dbc37965b9c714fad2102705dbd Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 13 Jan 2023 21:04:56 +0530
Subject: [PATCH 305/346] Added ios.bzl

---
 mediapipe/tasks/ios/ios.bzl | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 mediapipe/tasks/ios/ios.bzl

diff --git a/mediapipe/tasks/ios/ios.bzl b/mediapipe/tasks/ios/ios.bzl
new file mode 100644
index 000000000..8fe2a24a1
--- /dev/null
+++ b/mediapipe/tasks/ios/ios.bzl
@@ -0,0 +1,3 @@
+"""MediaPipe Task Library Helper Rules for iOS"""
+
+MPP_TASK_MINIMUM_OS_VERSION = "11.0"

From 9e0b85c9b58b0395442c3e8cdeee46e45c8af380 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 13 Jan 2023 21:05:17 +0530
Subject: [PATCH 306/346] Added module name for iOS text classifier

---
 mediapipe/tasks/ios/text/text_classifier/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD
index aef68c9fe..1afddb5d4 100644
--- a/mediapipe/tasks/ios/text/text_classifier/BUILD
+++ b/mediapipe/tasks/ios/text/text_classifier/BUILD
@@ -46,6 +46,7 @@ objc_library(
         "-std=c++17",
         "-x objective-c++",
     ],
+    module_name = "MPPTextClassifier",
    deps = [
         ":MPPTextClassifierOptions",
         ":MPPTextClassifierResult",

From 2a53d78ae44bf27ec81ef795e51a5eb6fb863398 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 13 Jan 2023 21:05:44 +0530
Subject: [PATCH 307/346] Added Swift and Objective-C tests for iOS text
 classifier

---
 .../tasks/ios/test/text/text_classifier/BUILD |  82 +++++
 .../text_classifier/MPPTextClassifierTests.m  | 281 ++++++++++++++++++
 .../text_classifier/TextClassifierTests.swift | 237 +++++++++++++++
 3 files changed, 600 insertions(+)
 create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/BUILD
 create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
 create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift

diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD
new file mode 100644
index 000000000..b69202b64
--- /dev/null
+++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD
@@ -0,0 +1,82 @@
+load(
+    "@build_bazel_rules_apple//apple:ios.bzl",
+    "ios_unit_test",
+)
+load(
+    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
+    "tflite_ios_lab_runner"
+)
+load(
+    "@build_bazel_rules_swift//swift:swift.bzl",
+    "swift_library"
+)
+load(
"//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION" +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. +TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + +objc_library( + name = "MPPTextClassifierObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextClassifierTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = [], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], + +) + +ios_unit_test( + name = "MPPTextClassifierObjcTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags =[], + deps = [ + ":MPPTextClassifierObjcTestLibrary", + ], +) + +swift_library( + name = "MPPTextClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TextClassifierTests.swift"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], +) + +ios_unit_test( + name = "MPPTextClassifierSwiftTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextClassifierSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m new file mode 100644 index 000000000..3e2fe4bef --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -0,0 +1,281 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kRegexTextClassifierModelName = + @"test_model_text_classifier_with_regex_tokenizer"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; +static NSString *const kPositiveText = @"it's a charming and often affecting journey"; +static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; + +#define AssertEqualErrors(error, expectedError) \ + XCTAssertNotNil(error); \ + XCTAssertEqualObjects(error.domain, expectedError.domain); \ + XCTAssertEqual(error.code, expectedError.code); \ + XCTAssertNotEqual( \ + [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ + NSNotFound) + +#define AssertEqualCategoryArrays(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \ + } + +#define AssertTextClassifierResultHasOneHead(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); \ + \ + XCTAssertNotNil(textClassifierResult.classificationResult); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + +- (void)setUp { +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the + // class. 
+} + ++ (NSArray *)expectedBertResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil], + [[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForEdgeCaseTests { + return @[ [[MPPCategory alloc] initWithIndex:0 + score:0.956187f + categoryName:@"negative" + displayName:nil] ]; +} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + return filePath; +} + +- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions; +} + +- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:nil]; + XCTAssertNotNil(textClassifier); + + return textClassifier; +} + +- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions + failsWithExpectedError:(NSError *)expectedError { + NSError *error = nil; + MPPTextClassifier *textClassifier = + [[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error]; + XCTAssertNil(textClassifier); + AssertEqualErrors(error, expectedError); +} + +- (void)assertResultsOfClassifyText:(NSString *)text + usingTextClassifier:(MPPTextClassifier *)textClassifier + equalsCategories:(NSArray *)expectedCategories { + MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil]; + AssertTextClassifierResultHasOneHead(negativeResult); + AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories, + expectedCategories); +} + +- (void)testCreateTextClassifierFailsWithMissingModelPath { + NSString *modelPath = [self filePathWithName:@"" extension:@""]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:&error]; + XCTAssertNil(textClassifier); + + NSError *expectedError = [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: ExternalFile must specify at least one of 
'file_content', " + @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." + }]; + AssertEqualErrors(error, expectedError); +} + +- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"positive" ]; + options.categoryDenylist = @[ @"negative" ]; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: `category_allowlist` and " + @"`category_denylist` are mutually exclusive options." + }]]; +} + +- (void)testCreateTextClassifierFailsWithInvalidMaxResults { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 0; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: Invalid `max_results` option: " + @"value must be != 0." + }]]; +} + +- (void)testClassifyWithBertSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kBertTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForNegativeText]]; + + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithRegexSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kRegexTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForNegativeText]]; + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithMaxResultsSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 1; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryAllowListSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"negative" ]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options + error:&error]; + XCTAssertNotNil(textClassifier); + XCTAssertNil(error); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryDenyListSucceeds { + MPPTextClassifierOptions *options = + [self 
textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryDenylist = @[ @"positive" ]; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithScoreThresholdSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.scoreThreshold = 0.5f; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +@end diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift new file mode 100644 index 000000000..d2d433c22 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -0,0 +1,237 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import XCTest + +import MPPCommon + +@testable import MPPTextClassifier + +class TextClassifierTests: XCTestCase { + + static let bundle = Bundle(for: TextClassifierTests.self) + + static let kBertModelPath = bundle.path( + forResource: "bert_text_classifier", + ofType: "tflite") + + static let kPositiveText = "it's a charming and often affecting journey" + + static let kNegativeText = "unflinchingly bleak and desperate" + + static let kBertNegativeTextResults = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ResultCategory( + index: 1, + score: 0.043812, + categoryName: "positive", + displayName: nil) + ] + + static let kBertNegativeTextResultsForEdgeTestCases = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ] + + func assertEqualErrorDescriptions( + _ error: Error, expectedLocalizedDescription:String) { + XCTAssertEqual( + error.localizedDescription, + expectedLocalizedDescription) + } + + func assertCategoriesAreEqual( + category: ResultCategory, + expectedCategory: ResultCategory) { + XCTAssertEqual( + category.index, + expectedCategory.index) + XCTAssertEqual( + category.score, + expectedCategory.score, + accuracy:1e-6) + XCTAssertEqual( + category.categoryName, + expectedCategory.categoryName) + XCTAssertEqual( + category.displayName, + expectedCategory.displayName) + } + + func assertEqualCategoryArrays( + categoryArray: [ResultCategory], + expectedCategoryArray:[ResultCategory]) { + + XCTAssertEqual(categoryArray.count, expectedCategoryArray.count) + + for (category, expectedCategory) in + zip(categoryArray, expectedCategoryArray) { + assertCategoriesAreEqual( + category:category, + expectedCategory:expectedCategory) + } + } + + func assertTextClassifierResultHasOneHead( + _ textClassifierResult: TextClassifierResult) { + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + } + + func textClassifierOptionsWithModelPath( + _ modelPath: String?) 
throws -> TextClassifierOptions { + let modelPath = try XCTUnwrap(modelPath) + + let textClassifierOptions = TextClassifierOptions(); + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions + } + + func assertCreateTextClassifierThrowsError( + textClassifierOptions: TextClassifierOptions, + expectedErrorDescription: String) { + do { + let textClassifier = try TextClassifier(options:textClassifierOptions) + XCTAssertNil(textClassifier) + } + catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: expectedErrorDescription) + } + } + + func assertResultsForClassify( + text: String, + using textClassifier: TextClassifier, + equals expectedCategories: [ResultCategory]) throws { + let textClassifierResult = + try XCTUnwrap( + textClassifier.classify(text: text)); + assertTextClassifierResultHasOneHead(textClassifierResult); + assertEqualCategoryArrays( + categoryArray: + textClassifierResult.classificationResult.classifications[0].categories, + expectedCategoryArray: expectedCategories); + } + + func testCreateTextClassifierWithInvalidMaxResultsFails() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.maxResults = 0 + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0. + """) + } + + func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws { + + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryAllowlist = ["positive"] + textClassifierOptions.categoryDenylist = ["positive"] + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \ + mutually exclusive options. 
+ """) + } + + func testClassifyWithBertSucceeds() throws { + + let modelPath = try XCTUnwrap(TextClassifierTests.kBertModelPath) + let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResults) + } + + func testClassifyWithMaxResultsSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.maxResults = 1 + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryAllowlistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryAllowlist = ["negative"]; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryDenylistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryDenylist = ["positive"]; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithScoreThresholdSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.scoreThreshold = 0.5; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + +} From c4c07acc1e5b2dbc37965b9c714fad2102705dbd Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:18:01 +0530 Subject: [PATCH 308/346] Updated comments of MPPCommonUtils --- .../common/utils/sources/MPPCommonUtils.mm | 141 +++++++----------- 1 file changed, 58 insertions(+), 83 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 9932dd13c..27b75515d 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -26,7 +26,7 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; namespace { - using absl::StatusCode; +using absl::StatusCode; } @implementation MPPCommonUtils @@ -72,95 +72,70 @@ namespace { return YES; } - // Creates the NSEror with the appropriate error - // MPPTasksErrorCode and message. 
MPPTasksErrorCode has a one to one - // mapping with MediaPipeTasksStatus starting from the value 1(MPPTasksErrorCodeError) - // and hence will be correctly initialized if directly cast from the integer code derived from - // MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of - // MediaPipeTasksStatusx. - // - // Stores a string including absl status code and message(if non empty) as the - // error message See - // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514 - // for explanation. absl::Status::message() can also be used but not always - // guaranteed to be non empty. + /** Converts the absl status message to an NSString. */ NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; - - // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of - // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum - // stored in the payload is extracted here to later map to the appropriate error code to be - // returned. In cases where the enum is not stored in (payload is NULL or the payload string - // cannot be converted to an integer), we set the error code value to be 1 - // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify - // any errors not falling into other categories.) Since payload is of type absl::Cord that can be - // type cast into an absl::optional, we use the std::stoi function to convert it into - // an integer code if possible. + MPPTasksErrorCode genericErrorCode = MPPTasksErrorCodeUnknownError; MPPTasksErrorCode errorCode = genericErrorCode; - // If errorCode is outside the range of enum values possible or is - // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign - // appropriate MPPTasksErrorCode in default cases. Note: - // The mapping to absl::Status::code() is done to generate a more specific error code than - // MPPTasksErrorCodeError in cases when the payload can't be mapped to - // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn - // returned without modification by MediaPipe cc library methods. 
- switch (status.code()) { - case StatusCode::kCancelled: - errorCode = MPPTasksErrorCodeCancelledError; - break; - case StatusCode::kUnknown: - errorCode = MPPTasksErrorCodeUnknownError; - break; - case StatusCode::kInvalidArgument: - errorCode = MPPTasksErrorCodeInvalidArgumentError; - break; - case StatusCode::kDeadlineExceeded: - errorCode = MPPTasksErrorCodeDeadlineExceededError; - break; - case StatusCode::kNotFound: - errorCode = MPPTasksErrorCodeNotFoundError; - break; - case StatusCode::kAlreadyExists: - errorCode = MPPTasksErrorCodeAlreadyExistsError; - break; - case StatusCode::kPermissionDenied: - errorCode = MPPTasksErrorCodePermissionDeniedError; - break; - case StatusCode::kResourceExhausted: - errorCode = MPPTasksErrorCodeResourceExhaustedError; - break; - case StatusCode::kFailedPrecondition: - errorCode = MPPTasksErrorCodeFailedPreconditionError; - break; - case StatusCode::kAborted: - errorCode = MPPTasksErrorCodeAbortedError; - break; - case StatusCode::kOutOfRange: - errorCode = MPPTasksErrorCodeOutOfRangeError; - break; - case StatusCode::kUnimplemented: - errorCode = MPPTasksErrorCodeUnimplementedError; - break; - case StatusCode::kInternal: - errorCode = MPPTasksErrorCodeInternalError; - break; - case StatusCode::kUnavailable: - errorCode = MPPTasksErrorCodeUnavailableError; - break; - case StatusCode::kDataLoss: - errorCode = MPPTasksErrorCodeDataLossError; - break; - case StatusCode::kUnauthenticated: - errorCode = MPPTasksErrorCodeUnauthenticatedError; - break; - default: - errorCode = genericErrorCode; - break; - } + /** Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits + * absl::StatusCode::kOk. */ + switch (status.code()) { + case StatusCode::kCancelled: + errorCode = MPPTasksErrorCodeCancelledError; + break; + case StatusCode::kUnknown: + errorCode = MPPTasksErrorCodeUnknownError; + break; + case StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case StatusCode::kDeadlineExceeded: + errorCode = MPPTasksErrorCodeDeadlineExceededError; + break; + case StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeNotFoundError; + break; + case StatusCode::kAlreadyExists: + errorCode = MPPTasksErrorCodeAlreadyExistsError; + break; + case StatusCode::kPermissionDenied: + errorCode = MPPTasksErrorCodePermissionDeniedError; + break; + case StatusCode::kResourceExhausted: + errorCode = MPPTasksErrorCodeResourceExhaustedError; + break; + case StatusCode::kFailedPrecondition: + errorCode = MPPTasksErrorCodeFailedPreconditionError; + break; + case StatusCode::kAborted: + errorCode = MPPTasksErrorCodeAbortedError; + break; + case StatusCode::kOutOfRange: + errorCode = MPPTasksErrorCodeOutOfRangeError; + break; + case StatusCode::kUnimplemented: + errorCode = MPPTasksErrorCodeUnimplementedError; + break; + case StatusCode::kInternal: + errorCode = MPPTasksErrorCodeInternalError; + break; + case StatusCode::kUnavailable: + errorCode = MPPTasksErrorCodeUnavailableError; + break; + case StatusCode::kDataLoss: + errorCode = MPPTasksErrorCodeDataLossError; + break; + case StatusCode::kUnauthenticated: + errorCode = MPPTasksErrorCodeUnauthenticatedError; + break; + default: + errorCode = genericErrorCode; + break; + } [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; return NO; From 95f9f0fb88c209b147b7822c102f5003c22d3c16 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 13 Jan 2023 21:18:10 +0530 Subject: [PATCH 309/346] Updated formatting --- 
.../tasks/ios/common/sources/MPPCommon.h | 30 +++++++++++++++++-- .../sources/MPPTextClassifier.h | 6 ++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index f8047fc35..0f885a8c2 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -18,8 +18,7 @@ NS_ASSUME_NONNULL_BEGIN /** * @enum MPPTasksErrorCode - * This enum specifies error codes for MediaPipe Task Library. - * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C ++libray. + * This enum specifies error codes for errors thrown by iOS MediaPipe Task Library. */ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { @@ -27,35 +26,60 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { /** Indicates the operation was cancelled, typically by the caller. */ MPPTasksErrorCodeCancelledError = 1, + /** Indicates an unknown error occurred. */ MPPTasksErrorCodeUnknownError = 2, + /** Indicates the caller specified an invalid argument, such as a malformed filename. */ MPPTasksErrorCodeInvalidArgumentError = 3, + /** Indicates a deadline expired before the operation could complete. */ MPPTasksErrorCodeDeadlineExceededError = 4, + /** Indicates some requested entity (such as a file or directory) was not found. */ MPPTasksErrorCodeNotFoundError = 5, - /** Indicates that the entity a caller attempted to create (such as a file or directory) is already present. */ + + /** Indicates that the entity a caller attempted to create (such as a file or directory) is + already present. */ MPPTasksErrorCodeAlreadyExistsError = 6, + /** Indicates that the caller does not have permission to execute the specified operation. */ MPPTasksErrorCodePermissionDeniedError = 7, + /** Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire + file system is out of space. */ MPPTasksErrorCodeResourceExhaustedError = 8, + /** Indicates that the operation was rejected because the system is not in a state required for + the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" + operation is applied to a non-directory, etc. */ MPPTasksErrorCodeFailedPreconditionError = 9, + /** Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer + check failure or a failed transaction. */ MPPTasksErrorCodeAbortedError = 10, + /** Indicates the operation was attempted past the valid range, such as seeking or reading past an + end-of-file. */ MPPTasksErrorCodeOutOfRangeError = 11, + /** Indicates the operation is not implemented or supported in this service. In this case, the + operation should not be re-attempted. */ MPPTasksErrorCodeUnimplementedError = 12, + /** Indicates an internal error has occurred and some invariants expected by the underlying system + have not been satisfied. This error code is reserved for serious errors. */ MPPTasksErrorCodeInternalError = 13, + /** Indicates the service is currently unavailable and that this is most likely a transient + condition. */ MPPTasksErrorCodeUnavailableError = 14, + /** Indicates that unrecoverable data loss or corruption has occurred. */ MPPTasksErrorCodeDataLossError = 15, + /** Indicates that the request does not have valid authentication credentials for the operation. + */ MPPTasksErrorCodeUnauthenticatedError = 16, // The first error code in MPPTasksErrorCode (for internal use only). 
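The reworked MPPTasksErrorCode values line up one-to-one with absl's canonical status codes (kCancelled = 1 through kUnauthenticated = 16, with kOk = 0 deliberately omitted), which is what lets MPPCommonUtils translate an absl::Status with a plain case-by-case switch. A hedged C++ sketch of that correspondence; the function below is illustrative only and is not part of the library, which keeps the explicit switch:

    #include "absl/status/status.h"

    // Sketch: because MPPTasksErrorCode mirrors the canonical numbering of
    // absl::StatusCode (1..16), the case-by-case switch in MPPCommonUtils.mm
    // behaves like a range-checked cast; any code outside that range falls
    // back to the "unknown" error (value 2, MPPTasksErrorCodeUnknownError).
    int TaskErrorCodeFor(absl::StatusCode code) {
      const int raw = static_cast<int>(code);
      if (raw < 1 || raw > 16) return 2;
      return raw;
    }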
diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index e33615dab..33d3c8970 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -79,7 +79,7 @@ NS_SWIFT_NAME(TextClassifier) * error in initializing the text classifier. */ - (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options - error:(NSError **)error NS_DESIGNATED_INITIALIZER; + error:(NSError **)error NS_DESIGNATED_INITIALIZER; /** * Performs classification on the input text. @@ -90,8 +90,8 @@ NS_SWIFT_NAME(TextClassifier) * * @return A `MPPTextClassifierResult` object that contains a list of text classifications. */ -- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error NS_SWIFT_NAME(classify(text:)); - +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(classify(text:)); - (instancetype)init NS_UNAVAILABLE; From 69757d7924f84dfbe50520bba1bf1fdd4f177f16 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 13 Jan 2023 09:03:46 -0800 Subject: [PATCH 310/346] Internal change PiperOrigin-RevId: 501862194 --- mediapipe/framework/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 83346dad1..da8ef3b3e 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -96,7 +96,9 @@ mediapipe_proto_library( mediapipe_proto_library( name = "mediapipe_options_proto", srcs = ["mediapipe_options.proto"], - visibility = [":mediapipe_internal"], + visibility = [ + ":mediapipe_internal", + ], ) mediapipe_proto_library( From f997c0ab1a8bc69d0ef8760061a515313144af8c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 13 Jan 2023 09:52:07 -0800 Subject: [PATCH 311/346] Reject RegionOfInterest in not supported tasks PiperOrigin-RevId: 501872455 --- .../vision/core/vision_task_runner.test.ts | 41 +++++++++++++++---- .../web/vision/core/vision_task_runner.ts | 9 +++- .../gesture_recognizer/gesture_recognizer.ts | 2 +- .../gesture_recognizer_test.ts | 8 ++++ .../vision/hand_landmarker/hand_landmarker.ts | 2 +- .../hand_landmarker/hand_landmarker_test.ts | 8 ++++ .../image_classifier/image_classifier.ts | 2 +- .../vision/image_embedder/image_embedder.ts | 2 +- .../vision/object_detector/object_detector.ts | 2 +- .../object_detector/object_detector_test.ts | 8 ++++ 10 files changed, 70 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 4567134b8..4eb51afdb 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -41,14 +41,14 @@ class VisionTaskRunnerFake extends VisionTaskRunner { expectedImageSource?: ImageSource; expectedNormalizedRect?: NormalizedRect; - constructor() { + constructor(roiAllowed = true) { super( jasmine.createSpyObj([ 'addProtoToStream', 'addGpuBufferAsImageToStream', 'setAutoRenderToScreen', 'registerModelResourcesGraphService', 'finishProcessing' ]), - IMAGE_STREAM, NORM_RECT_STREAM); + IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed); this.fakeGraphRunner = this.graphRunner as unknown as jasmine.SpyObj; @@ -71,6 +71,9 @@ class VisionTaskRunnerFake extends VisionTaskRunner { expect(timestamp).toBe(TIMESTAMP); 
expect(imageSource).toBe(this.expectedImageSource!); }); + + // SetOptions with a modelAssetBuffer runs synchonously + void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}}); } protected override refreshGraph(): void {} @@ -108,28 +111,26 @@ class VisionTaskRunnerFake extends VisionTaskRunner { } describe('VisionTaskRunner', () => { - let visionTaskRunner: VisionTaskRunnerFake; - - beforeEach(async () => { + beforeEach(() => { addJasmineCustomFloatEqualityTester(); - visionTaskRunner = new VisionTaskRunnerFake(); - await visionTaskRunner.setOptions( - {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: true})); }); it('can clear running mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); // Clear running mode @@ -140,6 +141,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process images with video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(() => { visionTaskRunner.processImageData( @@ -148,6 +150,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process video with image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); // Use default for `useStreamMode` expect(() => { visionTaskRunner.processVideoData( @@ -163,6 +166,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); @@ -172,6 +176,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph with image processing options', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); @@ -184,6 +189,7 @@ describe('VisionTaskRunner', () => { describe('validates processing options', () => { it('with left > right', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -197,6 +203,7 @@ describe('VisionTaskRunner', () => { }); it('with top > bottom', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -210,6 +217,7 @@ describe('VisionTaskRunner', () => { }); it('with out of range values', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -222,7 +230,24 @@ describe('VisionTaskRunner', () => { }).toThrowError('Expected RectF values to be in [0,1].'); }); + + it('without region of interest support', () => { + const visionTaskRunner = + new VisionTaskRunnerFake(/* roiAllowed= */ false); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 0.2, + top: 0.1, + bottom: 0.2, 
+ } + }); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('with non-90 degree rotation', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42}); }).toThrowError('Expected rotation to be a multiple of 90°.'); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 71cac920c..b3e8ed4db 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -42,13 +42,16 @@ export abstract class VisionTaskRunner extends TaskRunner { * @param normRectStreamName the name of the input normalized rect image * stream used to provide (mandatory) rotation and (optional) * region-of-interest. + * @param roiAllowed Whether this task supports Region-Of-Interest + * pre-processing * * @hideconstructor protected */ constructor( protected override readonly graphRunner: VisionGraphRunner, private readonly imageStreamName: string, - private readonly normRectStreamName: string) { + private readonly normRectStreamName: string, + private readonly roiAllowed: boolean) { super(graphRunner); } @@ -96,6 +99,10 @@ export abstract class VisionTaskRunner extends TaskRunner { const normalizedRect = new NormalizedRect(); if (imageProcessingOptions?.regionOfInterest) { + if (!this.roiAllowed) { + throw new Error('This task doesn\'t support region-of-interest.'); + } + const roi = imageProcessingOptions.regionOfInterest; if (roi.left >= roi.right || roi.top >= roi.bottom) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 1b7201b9a..beea263ce 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -126,7 +126,7 @@ export class GestureRecognizer extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new GestureRecognizerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index dfc252eb6..b2a2c0d72 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -250,6 +250,14 @@ describe('GestureRecognizer', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + gestureRecognizer.recognize( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index b51fb6a52..cd0459372 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -116,7 +116,7 @@ export class HandLandmarker extends VisionTaskRunner { glCanvas?: 
HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new HandLandmarkerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 0abd1df27..5fd493424 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -203,6 +203,14 @@ describe('HandLandmarker', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + handLandmarker.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index cb2849cd8..071513b19 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -101,7 +101,7 @@ export class ImageClassifier extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ true); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 788646e6d..fdeb92f3f 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -104,7 +104,7 @@ export class ImageEmbedder extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ true); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 5741a3a0c..5b581432d 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -100,7 +100,7 @@ export class ObjectDetector extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index ceb96acb1..9dd64c0b6 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -170,6 +170,14 @@ describe('ObjectDetector', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + objectDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task 
doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { const detectionProtos: Uint8Array[] = []; From aef4cca40610ced2efd0ed45a465c43368f4a893 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 13 Jan 2023 13:45:47 -0800 Subject: [PATCH 312/346] Copy README.md to NPM package root PiperOrigin-RevId: 501929871 --- mediapipe/tasks/web/BUILD | 161 +--------------------------- mediapipe/tasks/web/audio.ts | 25 ----- mediapipe/tasks/web/audio/BUILD | 58 +++++++++- mediapipe/tasks/web/audio/index.ts | 14 ++- mediapipe/tasks/web/text.ts | 25 ----- mediapipe/tasks/web/text/BUILD | 56 +++++++++- mediapipe/tasks/web/text/index.ts | 14 ++- mediapipe/tasks/web/vision.ts | 35 ------ mediapipe/tasks/web/vision/BUILD | 56 +++++++++- mediapipe/tasks/web/vision/index.ts | 30 ++++-- 10 files changed, 216 insertions(+), 258 deletions(-) delete mode 100644 mediapipe/tasks/web/audio.ts delete mode 100644 mediapipe/tasks/web/text.ts delete mode 100644 mediapipe/tasks/web/vision.ts diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 02bd70dd0..ff947ef54 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -1,158 +1,5 @@ -# This contains the MediaPipe Tasks NPM package definitions. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") -load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") -load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") -load( - "//mediapipe/framework/tool:mediapipe_files.bzl", - "mediapipe_files", -) - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_files(srcs = [ - "wasm/audio_wasm_internal.js", - "wasm/audio_wasm_internal.wasm", - "wasm/audio_wasm_nosimd_internal.js", - "wasm/audio_wasm_nosimd_internal.wasm", - "wasm/text_wasm_internal.js", - "wasm/text_wasm_internal.wasm", - "wasm/text_wasm_nosimd_internal.js", - "wasm/text_wasm_nosimd_internal.wasm", - "wasm/vision_wasm_internal.js", - "wasm/vision_wasm_internal.wasm", - "wasm/vision_wasm_nosimd_internal.js", - "wasm/vision_wasm_nosimd_internal.wasm", +exports_files([ + "karma.conf.ts", + "package.json", + "rollup.config.mjs", ]) - -# Audio - -mediapipe_ts_library( - name = "audio_lib", - srcs = ["audio.ts"], - deps = ["//mediapipe/tasks/web/audio:audio_lib"], -) - -rollup_bundle( - name = "audio_bundle", - config_file = "rollup.config.mjs", - entry_point = "audio.ts", - format = "esm", - output_dir = False, - sourcemap = "false", - deps = [ - ":audio_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-terser", - "@npm//google-protobuf", - ], -) - -pkg_npm( - name = "audio_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "audio", - "__DESCRIPTION__": "MediaPipe Audio Tasks", - "__TYPES__": "audio.d.ts", - }, - tgz = "audio.tgz", - deps = [ - "wasm/audio_wasm_internal.js", - "wasm/audio_wasm_internal.wasm", - "wasm/audio_wasm_nosimd_internal.js", - "wasm/audio_wasm_nosimd_internal.wasm", - ":audio_bundle", - "//mediapipe/tasks/web/audio:README.md", - ], -) - -# Text - -mediapipe_ts_library( - name = "text_lib", - srcs = ["text.ts"], - deps = ["//mediapipe/tasks/web/text:text_lib"], -) - -rollup_bundle( - name = "text_bundle", - config_file = "rollup.config.mjs", - entry_point = "text.ts", - format = "esm", - output_dir = False, - sourcemap = "false", - deps = [ - ":text_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-terser", - 
"@npm//google-protobuf", - ], -) - -pkg_npm( - name = "text_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "text", - "__DESCRIPTION__": "MediaPipe Text Tasks", - "__TYPES__": "text.d.ts", - }, - tgz = "text.tgz", - deps = [ - "wasm/text_wasm_internal.js", - "wasm/text_wasm_internal.wasm", - "wasm/text_wasm_nosimd_internal.js", - "wasm/text_wasm_nosimd_internal.wasm", - ":text_bundle", - "//mediapipe/tasks/web/text:README.md", - ], -) - -# Vision - -mediapipe_ts_library( - name = "vision_lib", - srcs = ["vision.ts"], - deps = ["//mediapipe/tasks/web/vision:vision_lib"], -) - -rollup_bundle( - name = "vision_bundle", - config_file = "rollup.config.mjs", - entry_point = "vision.ts", - format = "esm", - output_dir = False, - sourcemap = "false", - deps = [ - ":vision_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-terser", - "@npm//google-protobuf", - ], -) - -pkg_npm( - name = "vision_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "vision", - "__DESCRIPTION__": "MediaPipe Vision Tasks", - "__TYPES__": "vision.d.ts", - }, - tgz = "vision_pkg.tgz", - deps = [ - "wasm/vision_wasm_internal.js", - "wasm/vision_wasm_internal.wasm", - "wasm/vision_wasm_nosimd_internal.js", - "wasm/vision_wasm_nosimd_internal.wasm", - ":vision_bundle", - "//mediapipe/tasks/web/vision:README.md", - ], -) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts deleted file mode 100644 index 2f4fb0315..000000000 --- a/mediapipe/tasks/web/audio.ts +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl, FilesetResolver as FilesetResolverImpl} from '../../tasks/web/audio/index'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. -const AudioClassifier = AudioClassifierImpl; -const AudioEmbedder = AudioEmbedderImpl; -const FilesetResolver = FilesetResolverImpl; - -export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 50a611f41..7e05263fe 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1,11 +1,15 @@ # This contains the MediaPipe Audio Tasks. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) package(default_visibility = ["//mediapipe/tasks:internal"]) -exports_files(["README.md"]) - mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], @@ -16,3 +20,53 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core:fileset_resolver", ], ) + +mediapipe_files(srcs = [ + "wasm/audio_wasm_internal.js", + "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", +]) + +rollup_bundle( + name = "audio_bundle", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "false", + deps = [ + ":audio_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "package_json", + srcs = ["//mediapipe/tasks/web:package.json"], + outs = ["package.json"], + cmd = "cp $< $@", +) + +pkg_npm( + name = "audio_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = ["README.md"], + substitutions = { + "__NAME__": "audio", + "__DESCRIPTION__": "MediaPipe Audio Tasks", + "__TYPES__": "audio.d.ts", + }, + tgz = "audio.tgz", + deps = [ + "wasm/audio_wasm_internal.js", + "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", + ":audio_bundle", + ":package_json", + ], +) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index dbad8c617..44fa7eb25 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,6 +14,14 @@ * limitations under the License. */ -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; -export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; -export * from '../../../tasks/web/core/fileset_resolver'; +import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; + +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts deleted file mode 100644 index 0636714b8..000000000 --- a/mediapipe/tasks/web/text.ts +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {FilesetResolver as FilesetResolverImpl, TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. -const FilesetResolver = FilesetResolverImpl; -const TextClassifier = TextClassifierImpl; -const TextEmbedder = TextEmbedderImpl; - -export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 077b25645..6f019aca1 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -1,10 +1,21 @@ # This contains the MediaPipe Text Tasks. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) package(default_visibility = ["//mediapipe/tasks:internal"]) -exports_files(["README.md"]) +mediapipe_files(srcs = [ + "wasm/text_wasm_internal.js", + "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", +]) mediapipe_ts_library( name = "text_lib", @@ -16,3 +27,46 @@ mediapipe_ts_library( "//mediapipe/tasks/web/text/text_embedder", ], ) + +rollup_bundle( + name = "text_bundle", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "false", + deps = [ + ":text_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "package_json", + srcs = ["//mediapipe/tasks/web:package.json"], + outs = ["package.json"], + cmd = "cp $< $@", +) + +pkg_npm( + name = "text_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = ["README.md"], + substitutions = { + "__NAME__": "text", + "__DESCRIPTION__": "MediaPipe Text Tasks", + "__TYPES__": "text.d.ts", + }, + tgz = "text.tgz", + deps = [ + "wasm/text_wasm_internal.js", + "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", + ":package_json", + ":text_bundle", + ], +) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index a28e4dd1c..2c9e6fead 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,6 +14,14 @@ * limitations under the License. */ -export * from '../../../tasks/web/text/text_classifier/text_classifier'; -export * from '../../../tasks/web/text/text_embedder/text_embedder'; -export * from '../../../tasks/web/core/fileset_resolver'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; +import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. 
+const FilesetResolver = FilesetResolverImpl; +const TextClassifier = TextClassifierImpl; +const TextEmbedder = TextEmbedderImpl; + +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts deleted file mode 100644 index f1ced59af..000000000 --- a/mediapipe/tasks/web/vision.ts +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {FilesetResolver as FilesetResolverImpl, GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. -const FilesetResolver = FilesetResolverImpl; -const GestureRecognizer = GestureRecognizerImpl; -const HandLandmarker = HandLandmarkerImpl; -const ImageClassifier = ImageClassifierImpl; -const ImageEmbedder = ImageEmbedderImpl; -const ObjectDetector = ObjectDetectorImpl; - -export { - FilesetResolver, - GestureRecognizer, - HandLandmarker, - ImageClassifier, - ImageEmbedder, - ObjectDetector -}; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index ea022e900..76b0c084e 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -1,10 +1,21 @@ # This contains the MediaPipe Vision Tasks. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) package(default_visibility = ["//mediapipe/tasks:internal"]) -exports_files(["README.md"]) +mediapipe_files(srcs = [ + "wasm/vision_wasm_internal.js", + "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", +]) mediapipe_ts_library( name = "vision_lib", @@ -19,3 +30,46 @@ mediapipe_ts_library( "//mediapipe/tasks/web/vision/object_detector", ], ) + +rollup_bundle( + name = "vision_bundle", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "false", + deps = [ + ":vision_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "package_json", + srcs = ["//mediapipe/tasks/web:package.json"], + outs = ["package.json"], + cmd = "cp $< $@", +) + +pkg_npm( + name = "vision_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = ["README.md"], + substitutions = { + "__NAME__": "vision", + "__DESCRIPTION__": "MediaPipe Vision Tasks", + "__TYPES__": "vision.d.ts", + }, + tgz = "vision_pkg.tgz", + deps = [ + "wasm/vision_wasm_internal.js", + "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", + ":package_json", + ":vision_bundle", + ], +) diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 0337a0f2f..e13f8183f 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,9 +14,27 @@ * limitations under the License. */ -export * from '../../../tasks/web/vision/image_classifier/image_classifier'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -export * from '../../../tasks/web/vision/object_detector/object_detector'; -export * from '../../../tasks/web/core/fileset_resolver'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; +import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; +import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. 
+const FilesetResolver = FilesetResolverImpl; +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + FilesetResolver, + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; From 92a2e02ace78105cf120bb68a17c07df6dbd027f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 13 Jan 2023 17:02:59 -0800 Subject: [PATCH 313/346] Internal change PiperOrigin-RevId: 501971410 --- mediapipe/framework/deps/BUILD | 12 ++++++++++++ mediapipe/framework/profiler/BUILD | 1 + mediapipe/util/BUILD | 1 + 3 files changed, 14 insertions(+) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 7ff004f1e..7994aae75 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -88,6 +88,7 @@ cc_library( name = "message_matchers", testonly = True, hdrs = ["message_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". visibility = [ "//mediapipe/framework/port:__pkg__", @@ -145,6 +146,7 @@ cc_library( cc_library( name = "map_util", hdrs = ["map_util.h"], + # Use this library through "mediapipe/framework/port:map_util". visibility = ["//mediapipe/framework/port:__pkg__"], deps = ["//mediapipe/framework/port:logging"], @@ -180,6 +182,7 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], + # Use this library through "mediapipe/framework/port:point". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -199,6 +202,7 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], + # Use this library through "mediapipe/framework/port:rectangle". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -239,6 +243,7 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], + # Use this library through "mediapipe/framework/port:singleton". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -249,6 +254,7 @@ cc_library( cc_library( name = "source_location", hdrs = ["source_location.h"], + # Use this library through "mediapipe/framework/port:source_location". visibility = ["//mediapipe/framework/port:__pkg__"], ) @@ -265,6 +271,7 @@ cc_library( "status_builder.h", "status_macros.h", ], + # Use this library through "mediapipe/framework/port:status". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -288,6 +295,7 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -301,6 +309,7 @@ cc_library( name = "ret_check", srcs = ["ret_check.cc"], hdrs = ["ret_check.h"], + # Use this library through "mediapipe/framework/port:ret_check". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -321,6 +330,7 @@ cc_library( "//conditions:default": ["threadpool_pthread_impl.cc"], }), hdrs = ["threadpool.h"], + # Use this library through "mediapipe/framework/port:threadpool". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -335,6 +345,7 @@ cc_library( name = "topologicalsorter", srcs = ["topologicalsorter.cc"], hdrs = ["topologicalsorter.h"], + # Use this library through "mediapipe/framework/port:topologicalsorter". 
visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -345,6 +356,7 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], + # Use this library through "mediapipe/framework/port:vector". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 3b6976fc8..7ebfd3b8c 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -284,6 +284,7 @@ cc_library( "//mediapipe:ios": ["profiler_resource_util_ios.cc"], }), hdrs = ["profiler_resource_util.h"], + # We use Objective-C++ on iOS. copts = select({ "//conditions:default": [], diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 55c1df59f..555569552 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -186,6 +186,7 @@ cc_library( hdrs = [ "resource_util.h", ], + # We use Objective-C++ on iOS. copts = select({ "//conditions:default": [], From 30533be321744ddea7f37fea0bf77298596b9b92 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:00:10 +0530 Subject: [PATCH 314/346] Reformatted comments --- .../tasks/ios/common/sources/MPPCommon.h | 57 ++++++++++++------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 0f885a8c2..3f0a1a7b9 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -39,53 +39,70 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { /** Indicates some requested entity (such as a file or directory) was not found. */ MPPTasksErrorCodeNotFoundError = 5, - /** Indicates that the entity a caller attempted to create (such as a file or directory) is - already present. */ + /** + * Indicates that the entity a caller attempted to create (such as a file or directory) is + * already present. + */ MPPTasksErrorCodeAlreadyExistsError = 6, /** Indicates that the caller does not have permission to execute the specified operation. */ MPPTasksErrorCodePermissionDeniedError = 7, - /** Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire - file system is out of space. */ + /** + * Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire + * file system is out of space. + */ MPPTasksErrorCodeResourceExhaustedError = 8, - /** Indicates that the operation was rejected because the system is not in a state required for - the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" - operation is applied to a non-directory, etc. */ + /** + * Indicates that the operation was rejected because the system is not in a state required for + * the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" + * operation is applied to a non-directory, etc. + */ MPPTasksErrorCodeFailedPreconditionError = 9, - /** Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer - check failure or a failed transaction. */ + /** + * Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer + * check failure or a failed transaction. + */ MPPTasksErrorCodeAbortedError = 10, - /** Indicates the operation was attempted past the valid range, such as seeking or reading past an - end-of-file. 
*/ + /** + * Indicates the operation was attempted past the valid range, such as seeking or reading past an + * end-of-file. + */ MPPTasksErrorCodeOutOfRangeError = 11, - /** Indicates the operation is not implemented or supported in this service. In this case, the - operation should not be re-attempted. */ + /** + * Indicates the operation is not implemented or supported in this service. In this case, the + * operation should not be re-attempted. + */ MPPTasksErrorCodeUnimplementedError = 12, - /** Indicates an internal error has occurred and some invariants expected by the underlying system - have not been satisfied. This error code is reserved for serious errors. */ + /** + * Indicates an internal error has occurred and some invariants expected by the underlying system + * have not been satisfied. This error code is reserved for serious errors. + */ MPPTasksErrorCodeInternalError = 13, - /** Indicates the service is currently unavailable and that this is most likely a transient - condition. */ + /** + * Indicates the service is currently unavailable and that this is most likely a transient + * condition. + */ MPPTasksErrorCodeUnavailableError = 14, /** Indicates that unrecoverable data loss or corruption has occurred. */ MPPTasksErrorCodeDataLossError = 15, - /** Indicates that the request does not have valid authentication credentials for the operation. + /** + * Indicates that the request does not have valid authentication credentials for the operation. */ MPPTasksErrorCodeUnauthenticatedError = 16, - // The first error code in MPPTasksErrorCode (for internal use only). + /** The first error code in MPPTasksErrorCode (for internal use only). */ MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError, - // The last error code in MPPTasksErrorCode (for internal use only). + /** The last error code in MPPTasksErrorCode (for internal use only). */ MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError, } NS_SWIFT_NAME(TasksErrorCode); From 8ecf77f760c49fd319b80a2bd5daefaba5a7cd72 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:02:33 +0530 Subject: [PATCH 315/346] Updated comment style in methods --- mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 27b75515d..538023df6 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -72,7 +72,7 @@ using absl::StatusCode; return YES; } - /** Converts the absl status message to an NSString. */ + // Converts the absl status message to an NSString. NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; @@ -81,8 +81,8 @@ using absl::StatusCode; MPPTasksErrorCode errorCode = genericErrorCode; - /** Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits - * absl::StatusCode::kOk. */ + // Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits + // absl::StatusCode::kOk. 
switch (status.code()) { case StatusCode::kCancelled: errorCode = MPPTasksErrorCodeCancelledError; From f7fc8a6eca14b2c93fbe7a8c1c5162a1f9d59223 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:05:29 +0530 Subject: [PATCH 316/346] Updated method names in tests --- .../ios/test/text/text_classifier/MPPTextClassifierTests.m | 7 +++---- .../test/text/text_classifier/TextClassifierTests.swift | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index 3e2fe4bef..a8e541014 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -43,7 +43,6 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; #define AssertTextClassifierResultHasOneHead(textClassifierResult) \ XCTAssertNotNil(textClassifierResult); \ - \ XCTAssertNotNil(textClassifierResult.classificationResult); \ XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); @@ -156,7 +155,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; AssertEqualErrors(error, expectedError); } -- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList { +- (void)testCreateTextClassifierFailsWithBothAllowlistAndDenylist { MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; options.categoryAllowlist = @[ @"positive" ]; @@ -233,7 +232,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; expectedBertResultCategoriesForEdgeCaseTests]]; } -- (void)testClassifyWithCategoryAllowListSucceeds { +- (void)testClassifyWithCategoryAllowlistSucceeds { MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; options.categoryAllowlist = @[ @"negative" ]; @@ -250,7 +249,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; expectedBertResultCategoriesForEdgeCaseTests]]; } -- (void)testClassifyWithCategoryDenyListSucceeds { +- (void)testClassifyWithCategoryDenylistSucceeds { MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; options.categoryDenylist = @[ @"positive" ]; diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift index d2d433c22..01b5748cf 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -147,7 +147,7 @@ class TextClassifierTests: XCTestCase { """) } - func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws { + func testCreateTextClassifierWithCategoryAllowlistAndDenylistFails() throws { let textClassifierOptions = try XCTUnwrap( From a0b3e620e4d024259bd2637198dd3141767f12d9 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:12:27 +0530 Subject: [PATCH 317/346] Removed unused methods --- .../test/text/text_classifier/MPPTextClassifierTests.m | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m 
b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index a8e541014..5c0964e68 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -52,14 +52,6 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; @implementation MPPTextClassifierTests -- (void)setUp { -} - -- (void)tearDown { - // Put teardown code here. This method is called after the invocation of each test method in the - // class. -} - + (NSArray *)expectedBertResultCategoriesForNegativeText { return @[ [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil], From cf945d3aebc0b705117946cddd583ae1066ef97b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 13:59:51 +0530 Subject: [PATCH 318/346] Removed unused variable --- mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 538023df6..f3d9ecc79 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -77,9 +77,7 @@ using absl::StatusCode; stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; - MPPTasksErrorCode genericErrorCode = MPPTasksErrorCodeUnknownError; - - MPPTasksErrorCode errorCode = genericErrorCode; + MPPTasksErrorCode errorCode = MPPTasksErrorCodeUnknownError; // Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits // absl::StatusCode::kOk. 
@@ -133,7 +131,6 @@ using absl::StatusCode; errorCode = MPPTasksErrorCodeUnauthenticatedError; break; default: - errorCode = genericErrorCode; break; } From 67735a6fd30518bb68843a140841547540b0ee61 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 16 Jan 2023 14:01:10 +0530 Subject: [PATCH 319/346] Added category indices in iOS failure description --- .../text_classifier/MPPTextClassifierTests.m | 17 ++++---- .../text_classifier/TextClassifierTests.swift | 40 ++++++++++++++----- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index 5c0964e68..ebeaf863f 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -32,13 +32,16 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ NSNotFound) -#define AssertEqualCategoryArrays(categories, expectedCategories) \ - XCTAssertEqual(categories.count, expectedCategories.count); \ - for (int i = 0; i < categories.count; i++) { \ - XCTAssertEqual(categories[i].index, expectedCategories[i].index); \ - XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \ - XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \ - XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \ +#define AssertEqualCategoryArrays(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index, @"index i = %d", i); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6, \ + @"index i = %d", i); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName, \ + @"index i = %d", i); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName, \ + @"index i = %d", i); \ } #define AssertTextClassifierResultHasOneHead(textClassifierResult) \ diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift index 01b5748cf..186887778 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -60,33 +60,55 @@ class TextClassifierTests: XCTestCase { func assertCategoriesAreEqual( category: ResultCategory, - expectedCategory: ResultCategory) { + expectedCategory: ResultCategory, + indexInCategoryList: Int) { XCTAssertEqual( category.index, - expectedCategory.index) + expectedCategory.index, + String( + format: """ + category[%d].index and expectedCategory[%d].index are not equal. + """, indexInCategoryList)) XCTAssertEqual( category.score, expectedCategory.score, - accuracy:1e-6) + accuracy:1e-6, + String( + format: """ + category[%d].score and expectedCategory[%d].score are not equal. + """, indexInCategoryList)) XCTAssertEqual( category.categoryName, - expectedCategory.categoryName) + expectedCategory.categoryName, + String( + format: """ + category[%d].categoryName and expectedCategory[%d].categoryName are \ + not equal. 
+ """, indexInCategoryList)) XCTAssertEqual( category.displayName, - expectedCategory.displayName) + expectedCategory.displayName, + String( + format: """ + category[%d].displayName and expectedCategory[%d].displayName are \ + not equal. + """, indexInCategoryList)) } func assertEqualCategoryArrays( categoryArray: [ResultCategory], expectedCategoryArray:[ResultCategory]) { - XCTAssertEqual(categoryArray.count, expectedCategoryArray.count) + XCTAssertEqual( + categoryArray.count, + expectedCategoryArray.count) - for (category, expectedCategory) in - zip(categoryArray, expectedCategoryArray) { + for (index, (category, expectedCategory)) in + zip(categoryArray, expectedCategoryArray).enumerated() { assertCategoriesAreEqual( category:category, - expectedCategory:expectedCategory) + expectedCategory:expectedCategory, + indexInCategoryList:index) } } From ffd8486d0dc045af18c6ab0e1c7bf732e5a9f3ca Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 16 Jan 2023 08:35:56 -0800 Subject: [PATCH 320/346] Add a stub WriteProfile method to GraphProfilerStub. PiperOrigin-RevId: 502388455 --- mediapipe/framework/profiler/graph_profiler_stub.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/framework/profiler/graph_profiler_stub.h b/mediapipe/framework/profiler/graph_profiler_stub.h index 12a024fe8..72d5d7275 100644 --- a/mediapipe/framework/profiler/graph_profiler_stub.h +++ b/mediapipe/framework/profiler/graph_profiler_stub.h @@ -93,6 +93,7 @@ class GraphProfilerStub { PopulateGraphConfig populate_config = PopulateGraphConfig::kNo) { return absl::OkStatus(); } + inline absl::Status WriteProfile() { return absl::OkStatus(); } inline void Pause() {} inline void Resume() {} inline void Reset() {} From c1f5920ecf3beed2457d9df9ba0bdb7cd7e5a47c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 16 Jan 2023 12:57:44 -0800 Subject: [PATCH 321/346] Add web performance tracing to the MEDIAPIPE_PROFILING repertoire This records the MEDIAPIPE_PROFILING tracing annotations to the browser's trace using the user timing API. 
See https://developer.mozilla.org/en-US/docs/Web/API/User_Timing_API

To enable, build with
--define MEDIAPIPE_WEB_PROFILING=1 --define DRISHTI_PROFILING=1

PiperOrigin-RevId: 502422030
---
 mediapipe/framework/profiler/BUILD            | 20 ++++++
 .../profiler/web_performance_profiling.h      | 68 +++++++++++++++++++
 2 files changed, 88 insertions(+)
 create mode 100644 mediapipe/framework/profiler/web_performance_profiling.h

diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD
index 7ebfd3b8c..6184ed45b 100644
--- a/mediapipe/framework/profiler/BUILD
+++ b/mediapipe/framework/profiler/BUILD
@@ -127,6 +127,7 @@ cc_library(
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:advanced_proto_lite",
         "//mediapipe/framework/tool:name_util",
+        ":web_performance_profiling",
     ] + select({
         "//conditions:default": [],
     }) + select({
@@ -275,6 +276,25 @@ cc_test(
     ],
 )
 
+config_setting(
+    name = "mediapipe_web_profiling_enabled",
+    values = {
+        "define": "MEDIAPIPE_WEB_PROFILING=1",
+    },
+    visibility = ["//visibility:private"],
+)
+
+cc_library(
+    name = "web_performance_profiling",
+    hdrs = ["web_performance_profiling.h"],
+    defines = select({
+        ":mediapipe_web_profiling_enabled": ["MEDIAPIPE_WEB_PROFILING_ENABLED"],
+        "//conditions:default": [],
+    }),
+    visibility = ["//mediapipe:__subpackages__"],
+    deps = ["@com_google_absl//absl/strings"],
+)
+
 cc_library(
     name = "profiler_resource_util",
     srcs = ["profiler_resource_util_common.cc"] + select({
diff --git a/mediapipe/framework/profiler/web_performance_profiling.h b/mediapipe/framework/profiler/web_performance_profiling.h
new file mode 100644
index 000000000..47b76fe88
--- /dev/null
+++ b/mediapipe/framework/profiler/web_performance_profiling.h
@@ -0,0 +1,68 @@
+#ifndef MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_
+#define MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_
+
+#if MEDIAPIPE_WEB_PROFILING_ENABLED && __EMSCRIPTEN__
+#include <emscripten.h>
+
+#include "absl/strings/str_cat.h"
+
+// This records MediaPipe profiling events in the browser's performance trace.
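+//
+// Hypothetical usage sketch (assumed, not shown in this patch): a calculator
+// could wrap one of its steps in a scope, e.g.
+//
+//   MEDIAPIPE_WEB_PERFORMANCE_SCOPE(PROCESS, calculator_context);
+//
+// which emits performance.mark() entries named "<NodeName>::PROCESS_start"
+// and "<NodeName>::PROCESS_end", plus a performance.measure() span named
+// "<NodeName>::PROCESS", as implemented by the class below.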
+// To use, build with: +// --define MEDIAPIPE_PROFILING=1 --define MEDIAPIPE_WEB_PROFILING=1 + +namespace mediapipe { + +class WepPerformanceTraceScope { + public: + explicit WepPerformanceTraceScope(TraceEvent::EventType event_type, + const char* event_type_str, + CalculatorContext* cc) + : event_type_str_(event_type_str), cc_(cc) { + const auto& calculator_name = cc->NodeName(); + std::string start_name = + absl::StrCat(calculator_name, "::", event_type_str_, "_start"); + std::string timestamp_str = cc->InputTimestamp().DebugString(); + EM_ASM( + { + const startName = UTF8ToString($0); + const timestamp = UTF8ToString($1); + performance.mark(startName, {mp_timestamp : timestamp}); + }, + start_name.c_str(), timestamp_str.c_str()); + } + + ~WepPerformanceTraceScope() { + const auto& calculator_name = cc_->NodeName(); + std::string start_name = + absl::StrCat(calculator_name, "::", event_type_str_, "_start"); + std::string end_name = + absl::StrCat(calculator_name, "::", event_type_str_, "_end"); + std::string measure_name = + absl::StrCat(calculator_name, "::", event_type_str_); + EM_ASM( + { + const startName = UTF8ToString($0); + const endName = UTF8ToString($1); + const measureName = UTF8ToString($2); + performance.mark(endName); + performance.measure(measureName, startName, endName); + }, + start_name.c_str(), end_name.c_str(), measure_name.c_str()); + } + + private: + const char* event_type_str_; + CalculatorContext* cc_; +}; + +} // namespace mediapipe + +#define MEDIAPIPE_WEB_PERFORMANCE_SCOPE(event_type, calculator_context) \ + mediapipe::WepPerformanceTraceScope web_trace_scope( \ + mediapipe::TraceEvent::event_type, #event_type, calculator_context) + +#else +#define MEDIAPIPE_WEB_PERFORMANCE_SCOPE(event_type, calculator_context) +#endif // MEDIAPIPE_WEB_PROFILING_ENABLED && __EMSCRIPTEN__ + +#endif // MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_ From 7974171c3d0364a1bd79b6dc615b60ff57b175e7 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 17 Jan 2023 09:04:54 -0800 Subject: [PATCH 322/346] Merge `classificationResultList()` and `classificationResult()` to be `classificationResults()`, and similar for `embeddingResults()`. PiperOrigin-RevId: 502601043 --- .../AudioClassifierResult.java | 27 +++++++--------- .../audioembedder/AudioEmbedderResult.java | 31 +++++++++---------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java index 3102aa8cd..258e5725b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java @@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.ClassificationsPro import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.List; -import java.util.Optional; /** Represents the classification results generated by {@link AudioClassifier}. 
*/ @AutoValue @@ -40,8 +39,7 @@ public abstract class AudioClassifierResult implements TaskResult { for (ClassificationsProto.ClassificationResult proto : protoList) { classificationResultList.add(ClassificationResult.createFromProto(proto)); } - return new AutoValue_AudioClassifierResult( - Optional.of(classificationResultList), Optional.empty(), timestampMs); + return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs); } /** @@ -53,23 +51,22 @@ public abstract class AudioClassifierResult implements TaskResult { */ static AudioClassifierResult createFromProto( ClassificationsProto.ClassificationResult proto, long timestampMs) { - return new AutoValue_AudioClassifierResult( - Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs); + List classificationResultList = new ArrayList<>(); + classificationResultList.add(ClassificationResult.createFromProto(proto)); + return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs); } /** * A list of of timestamped {@link ClassificationResult} objects, each contains one set of results - * per classifier head. The list represents the audio classification result of an audio clip, and - * is only available when running with the audio clips mode. + * per classifier head. + * + *

In the "audio stream" mode, the list only contains one element, representing the + * classification result of the audio block that starts at {@link + * ClassificationResult.timestampMs} in the audio stream. Otherwise, in the "audio clips" mode, + * the list may include multiple {@link ClassificationResult} objects, each classifying an + * interval of the entire audio clip that starts at {@link ClassificationResult.timestampMs}. */ - public abstract Optional> classificationResultList(); - - /** - * Contains one set of results per classifier head. A {@link ClassificationResult} usually - * represents one audio classification result in an audio stream, and s only available when - * running with the audio stream mode. - */ - public abstract Optional classificationResult(); + public abstract List classificationResults(); @Override public abstract long timestampMs(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java index a986048f0..0cfd2297c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java @@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.List; -import java.util.Optional; /** Represents the embedding results generated by {@link AudioEmbedder}. */ @AutoValue @@ -35,12 +34,11 @@ public abstract class AudioEmbedderResult implements TaskResult { */ static AudioEmbedderResult createFromProtoList( List protoList, long timestampMs) { - List classificationResultList = new ArrayList<>(); + List embeddingResultList = new ArrayList<>(); for (EmbeddingsProto.EmbeddingResult proto : protoList) { - classificationResultList.add(EmbeddingResult.createFromProto(proto)); + embeddingResultList.add(EmbeddingResult.createFromProto(proto)); } - return new AutoValue_AudioEmbedderResult( - Optional.of(classificationResultList), Optional.empty(), timestampMs); + return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs); } /** @@ -52,23 +50,22 @@ public abstract class AudioEmbedderResult implements TaskResult { */ static AudioEmbedderResult createFromProto( EmbeddingsProto.EmbeddingResult proto, long timestampMs) { - return new AutoValue_AudioEmbedderResult( - Optional.empty(), Optional.of(EmbeddingResult.createFromProto(proto)), timestampMs); + List embeddingResultList = new ArrayList<>(); + embeddingResultList.add(EmbeddingResult.createFromProto(proto)); + return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs); } /** * A list of of timpstamped {@link EmbeddingResult} objects, each contains one set of results per - * embedder head. The list represents the audio embedding result of an audio clip, and is only - * available when running with the audio clips mode. + * embedder head. + * + *

In the "audio stream" mode, the list only contains one element, representing the embedding + * result of the audio block that starts at {@link EmbeddingResult.timestampMs} in the audio + * stream. Otherwise, in the "audio clips" mode, the list may include multiple {@link + * EmbeddingResult} objects, each contains the embedding of an interval of the entire audio clip + * that starts at {@link EmbeddingResult.timestampMs}. */ - public abstract Optional> embeddingResultList(); - - /** - * Contains one set of results per classifier head. A {@link EmbeddingResult} usually represents - * one audio embedding result in an audio stream, and is only available when running with the - * audio stream mode. - */ - public abstract Optional embeddingResult(); + public abstract List embeddingResults(); @Override public abstract long timestampMs(); From 7a4b450c501ca14f2a34dc6d2810361d7424e03d Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 17 Jan 2023 10:51:04 -0800 Subject: [PATCH 323/346] Resolve the error "call to 'abs' is ambiguous". PiperOrigin-RevId: 502630518 --- mediapipe/tasks/cc/components/containers/rect.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h index 551d91588..72c7a8acb 100644 --- a/mediapipe/tasks/cc/components/containers/rect.h +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#include #include namespace mediapipe::tasks::components::containers { @@ -48,10 +49,10 @@ struct RectF { }; inline bool operator==(const RectF& lhs, const RectF& rhs) { - return abs(lhs.left - rhs.left) < kRectFTolerance && - abs(lhs.top - rhs.top) < kRectFTolerance && - abs(lhs.right - rhs.right) < kRectFTolerance && - abs(lhs.bottom - rhs.bottom) < kRectFTolerance; + return std::fabs(lhs.left - rhs.left) < kRectFTolerance && + std::fabs(lhs.top - rhs.top) < kRectFTolerance && + std::fabs(lhs.right - rhs.right) < kRectFTolerance && + std::fabs(lhs.bottom - rhs.bottom) < kRectFTolerance; } RectF ToRectF(const Rect& rect, int image_height, int image_width); From 088249eb3697865dcd05c19dfb9065ddcf498d7e Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 11:56:05 -0800 Subject: [PATCH 324/346] Export all input and output types PiperOrigin-RevId: 502649430 --- mediapipe/tasks/web/audio/index.ts | 14 +++----------- mediapipe/tasks/web/text/index.ts | 14 +++----------- mediapipe/tasks/web/vision/index.ts | 30 ++++++----------------------- 3 files changed, 12 insertions(+), 46 deletions(-) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 44fa7eb25..dbad8c617 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,14 +14,6 @@ * limitations under the License. */ -import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier'; -import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder'; -import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. 
-const AudioClassifier = AudioClassifierImpl; -const AudioEmbedder = AudioEmbedderImpl; -const FilesetResolver = FilesetResolverImpl; - -export {AudioClassifier, AudioEmbedder, FilesetResolver}; +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index 2c9e6fead..f32c16c36 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,14 +14,6 @@ * limitations under the License. */ -import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; -import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; -import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. -const FilesetResolver = FilesetResolverImpl; -const TextClassifier = TextClassifierImpl; -const TextEmbedder = TextEmbedderImpl; - -export {FilesetResolver, TextClassifier, TextEmbedder}; +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index e13f8183f..2ba6ca812 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,27 +14,9 @@ * limitations under the License. */ -import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; -import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; -import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; -import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; - -// Declare the variables locally so that Rollup in OSS includes them explcilty -// as exports. 
-const FilesetResolver = FilesetResolverImpl; -const GestureRecognizer = GestureRecognizerImpl; -const HandLandmarker = HandLandmarkerImpl; -const ImageClassifier = ImageClassifierImpl; -const ImageEmbedder = ImageEmbedderImpl; -const ObjectDetector = ObjectDetectorImpl; - -export { - FilesetResolver, - GestureRecognizer, - HandLandmarker, - ImageClassifier, - ImageEmbedder, - ObjectDetector -}; +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; From 7894c92ab7edded4810665958cc904b4e768e29a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 17 Jan 2023 15:45:15 -0800 Subject: [PATCH 325/346] Internal change PiperOrigin-RevId: 502709070 --- mediapipe/util/log_fatal_to_breakpad.cc | 50 +++++++++++++++++++++++++ mediapipe/util/log_fatal_to_breakpad.h | 15 ++++++++ 2 files changed, 65 insertions(+) create mode 100644 mediapipe/util/log_fatal_to_breakpad.cc create mode 100644 mediapipe/util/log_fatal_to_breakpad.h diff --git a/mediapipe/util/log_fatal_to_breakpad.cc b/mediapipe/util/log_fatal_to_breakpad.cc new file mode 100644 index 000000000..45087f2e3 --- /dev/null +++ b/mediapipe/util/log_fatal_to_breakpad.cc @@ -0,0 +1,50 @@ +#include "mediapipe/util/log_fatal_to_breakpad.h" + +#import + +#include "absl/log/log.h" +#include "absl/log/log_sink.h" +#include "absl/log/log_sink_registry.h" +#import "googlemac/iPhone/Shared/GoogleIOSBreakpad/Classes/GoogleBreakpadController.h" + +namespace mediapipe { +namespace { +NSString* MakeNSString(absl::string_view str) { + return [[NSString alloc] initWithBytes:str.data() + length:str.length() + encoding:NSUTF8StringEncoding]; +} +} // namespace + +static NSString* const kFatalLogMessageKey = @"fatal_log_message"; + +class BreakpadFatalLogSink : public absl::LogSink { + public: + BreakpadFatalLogSink() + : breakpad_controller_([GoogleBreakpadController sharedInstance]) {} + void Send(const absl::LogEntry& entry) override { + if (entry.log_severity() != absl::LogSeverity::kFatal) return; + __block NSString* message = MakeNSString(entry.text_message_with_prefix()); + [breakpad_controller_ withBreakpadRef:^(BreakpadRef breakpad) { + // NOTE: This block runs on Breakpad's background queue. + if (!breakpad) return; + BreakpadAddUploadParameter(breakpad, kFatalLogMessageKey, message); + }]; + } + + private: + GoogleBreakpadController* breakpad_controller_; +}; + +absl::LogSink* GetBreakpadFatalLogSink() { + static BreakpadFatalLogSink sink; + return &sink; +} + +// This log sink is automatically enabled when including this library. +static const auto kRegisterLogSink = [] { + absl::AddLogSink(GetBreakpadFatalLogSink()); + return true; +}(); + +} // namespace mediapipe diff --git a/mediapipe/util/log_fatal_to_breakpad.h b/mediapipe/util/log_fatal_to_breakpad.h new file mode 100644 index 000000000..1712a9af8 --- /dev/null +++ b/mediapipe/util/log_fatal_to_breakpad.h @@ -0,0 +1,15 @@ +#ifndef MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ +#define MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ + +#include "absl/log/log_sink.h" + +namespace mediapipe { + +// Returns a singleton instance of a log sink that sends FATAL log messages to +// Breakpad. 
This log sink is enabled by default when this library is included +// in your binary. +absl::LogSink* GetBreakpadFatalLogSink(); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ From 0b97c6e67d316d023ee4ea61366a6c9d886f7ac4 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 15:45:28 -0800 Subject: [PATCH 326/346] Update the MP Wasm builds to latest version. PiperOrigin-RevId: 502709126 --- third_party/wasm_files.bzl | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 504f8567a..017d84466 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,72 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"], + sha256 = "d4d205d08e3e1b09662a9a358d0107e8a8023827ba9b6982a3777bb6c040f936", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1673996821002628"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"], + sha256 = "1b2ffe82b0a25d20188237a724a7cad68d068818a7738f91c69c782314f55965", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1673996823772372"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", - sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"], + sha256 = "1f367c2d667628b178251aec7fd464327351570edac4549450b11fb82f5f0fd4", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1673996826132845"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", - sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"], + sha256 = "35c6ad888c06025dba1f9c8edb70e6c7be7e94e45dc2c0236a2fcfe61991dc44", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1673996828935550"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"], + sha256 = "68c0134e0b3cb986c3526cd645f74cc5a1f6ab19292276ca7d3558b89801e205", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1673996831356232"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "84e5f5ac70f7718baeaa09a89b155abbea67386e7d50663301b3af7ef0941e74", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1669173782728605"], + sha256 = "df82bb192ea852dc1bcc8f9f28fbd8c3d6b219dc4fec2b2a92451678d98ee1f0", + urls = 
["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1673996834657078"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", - sha256 = "36f247673124e32535f217265b96508c1badee8fe2458c11c1efa95b6bec5daa", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1669173785027190"], + sha256 = "de1a4aabefb2e42ae4fee68b7e762e328623a163257a7ddc72365fc2502bd090", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1673996837104551"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", - sha256 = "cc74d90a8aaf6d006ec24048cc80c33f96baeeb0075a6c6739f30d41da54e450", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1669173787903754"], + sha256 = "828dd1e73fa9478a97a62539117f92b813833ab35d37a986c466df15a8cfdc7b", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1673996840120504"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "c3451423186766b08008e07ef6d52f628fcc0aca75beedd9bb4d87d380f29edd", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1669173790070986"], + sha256 = "c146b68523c256d41132230e811fc224dafb6a0bce6fc318c29dad37dfac06de", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1673996842448396"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "d1e8ad748913e3f190bfd3f72e0e8a4a308f78b918d54c79cec60a2cf30a49f0", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1669173792993881"], + sha256 = "8dbccaaf944ef1251cf78190450ab7074abea233e18ebb37d2c2ce0f18d14a0c", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1673996845499070"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", - sha256 = "e5f1b5e8264ff9a90371653cb0fdbf9ce3b30b712acbd72068af18ebca2293ac", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1669173794969702"], + sha256 = "705f9e3c2c62d12903ea2cadc22d2c328bc890f96fffc47b51f989471196ecea", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1673996847915731"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", - sha256 = "24351fe580e88f2065b1978b8b3c0f3ad7b90f1c95805aafa07971ce422b5854", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1669173797596874"], + sha256 = "c7ff6a7d8dc22380e2e8457a15a51b6bc1e70c6262fecca25825f54ecc593d1f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1673996850980344"], ) From d5e60eb658c231424209d5274d9edb28bebca367 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 20:51:40 -0800 Subject: [PATCH 327/346] Internal change PiperOrigin-RevId: 502764352 --- mediapipe/tasks/web/audio/BUILD | 19 ++++++++++++++----- mediapipe/tasks/web/audio/types.ts | 19 +++++++++++++++++++ mediapipe/tasks/web/text/BUILD | 19 ++++++++++++++----- mediapipe/tasks/web/text/types.ts | 19 +++++++++++++++++++ mediapipe/tasks/web/vision/BUILD | 25 +++++++++++++++++-------- mediapipe/tasks/web/vision/types.ts 
| 22 ++++++++++++++++++++++ 6 files changed, 105 insertions(+), 18 deletions(-) create mode 100644 mediapipe/tasks/web/audio/types.ts create mode 100644 mediapipe/tasks/web/text/types.ts create mode 100644 mediapipe/tasks/web/vision/types.ts diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 7e05263fe..409836800 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -10,15 +10,24 @@ load( package(default_visibility = ["//mediapipe/tasks:internal"]) +AUDIO_LIBS = [ + "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", +] + mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], visibility = ["//visibility:public"], - deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", - "//mediapipe/tasks/web/audio/audio_embedder", - "//mediapipe/tasks/web/core:fileset_resolver", - ], + deps = AUDIO_LIBS, +) + +mediapipe_ts_library( + name = "audio_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = AUDIO_LIBS, ) mediapipe_files(srcs = [ diff --git a/mediapipe/tasks/web/audio/types.ts b/mediapipe/tasks/web/audio/types.ts new file mode 100644 index 000000000..19073b708 --- /dev/null +++ b/mediapipe/tasks/web/audio/types.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 6f019aca1..ebe3403b2 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -17,15 +17,24 @@ mediapipe_files(srcs = [ "wasm/text_wasm_nosimd_internal.wasm", ]) +TEXT_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", +] + mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], visibility = ["//visibility:public"], - deps = [ - "//mediapipe/tasks/web/core:fileset_resolver", - "//mediapipe/tasks/web/text/text_classifier", - "//mediapipe/tasks/web/text/text_embedder", - ], + deps = TEXT_LIBS, +) + +mediapipe_ts_library( + name = "text_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = TEXT_LIBS, ) rollup_bundle( diff --git a/mediapipe/tasks/web/text/types.ts b/mediapipe/tasks/web/text/types.ts new file mode 100644 index 000000000..bd01b1c6f --- /dev/null +++ b/mediapipe/tasks/web/text/types.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 76b0c084e..8ba9c85b3 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -17,18 +17,27 @@ mediapipe_files(srcs = [ "wasm/vision_wasm_nosimd_internal.wasm", ]) +VISION_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/object_detector", +] + mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], visibility = ["//visibility:public"], - deps = [ - "//mediapipe/tasks/web/core:fileset_resolver", - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/hand_landmarker", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_embedder", - "//mediapipe/tasks/web/vision/object_detector", - ], + deps = VISION_LIBS, +) + +mediapipe_ts_library( + name = "vision_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = VISION_LIBS, ) rollup_bundle( diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts new file mode 100644 index 000000000..dd1f58294 --- /dev/null +++ b/mediapipe/tasks/web/vision/types.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; From e484bd681e03223a09619a6088dbb8b1a6c7557e Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 17 Jan 2023 20:53:07 -0800 Subject: [PATCH 328/346] Export all input and output types PiperOrigin-RevId: 502764544 --- mediapipe/tasks/web/audio/index.ts | 14 +++++++++++--- mediapipe/tasks/web/text/index.ts | 14 +++++++++++--- mediapipe/tasks/web/vision/index.ts | 30 +++++++++++++++++++++++------ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index dbad8c617..44fa7eb25 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,6 +14,14 @@ * limitations under the License. */ -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; -export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; -export * from '../../../tasks/web/core/fileset_resolver'; +import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; + +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index f32c16c36..2c9e6fead 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,6 +14,14 @@ * limitations under the License. */ -export * from '../../../tasks/web/core/fileset_resolver'; -export * from '../../../tasks/web/text/text_classifier/text_classifier'; -export * from '../../../tasks/web/text/text_embedder/text_embedder'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; +import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const FilesetResolver = FilesetResolverImpl; +const TextClassifier = TextClassifierImpl; +const TextEmbedder = TextEmbedderImpl; + +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 2ba6ca812..e13f8183f 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,9 +14,27 @@ * limitations under the License. 
*/ -export * from '../../../tasks/web/core/fileset_resolver'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -export * from '../../../tasks/web/vision/image_classifier/image_classifier'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder'; -export * from '../../../tasks/web/vision/object_detector/object_detector'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; +import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; +import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const FilesetResolver = FilesetResolverImpl; +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + FilesetResolver, + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; From 3688757d1706a5252de8196dfa56947dc0164671 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 18 Jan 2023 07:26:38 -0800 Subject: [PATCH 329/346] Fix `load_metadata_buffer` for empty metadata PiperOrigin-RevId: 502870428 --- mediapipe/tasks/python/metadata/metadata.py | 2 ++ .../python/test/metadata/metadata_test.py | 26 +++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 10a0b9b66..2327ebbdf 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -860,6 +860,8 @@ def get_metadata_buffer(model_buf): if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: buffer_index = meta.Buffer() metadata = tflite_model.Buffers(buffer_index) + if metadata.DataLength() == 0: + continue return metadata.DataAsNumpy().tobytes() return None diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py index bed9c2833..d892f1b61 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -550,7 +550,7 @@ class MetadataPopulatorTest(MetadataTest): ("The number of output tensors (1) should match the number of " "output tensor metadata (0)"), str(error.exception)) - def testLoadMetadataAndAssociatedFilesShouldSucceeds(self): + def testLoadMetadataAndAssociatedFilesShouldSucceed(self): # Create a src model with metadata and two associated files. src_model_buf = self._create_model_buf() populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) @@ -566,7 +566,7 @@ class MetadataPopulatorTest(MetadataTest): populator_src.get_model_buffer()) populator_dst.populate() - # Tests if the metadata and associated files are populated correctly. 
+ # Test if the metadata and associated files are populated correctly. dst_model_file = self.create_tempfile().full_path with open(dst_model_file, "wb") as f: f.write(populator_dst.get_model_buffer()) @@ -575,6 +575,28 @@ class MetadataPopulatorTest(MetadataTest): recorded_files = populator_dst.get_recorded_associated_file_list() self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + def testLoadMetadataAndAssociatedFilesShouldSucceedOnEmptyMetadata(self): + # When the user hasn't specified the metadata, but only the associated + # files, an empty metadata buffer is created. Previously, it caused an + # exception when reading. + + # Create a source model with two associated files but no metadata. + src_model_buf = self._create_model_buf() + populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) + populator_src.load_associated_files([self._file1, self._file2]) + populator_src.populate() + + # Create a model to be populated with the files from `src_model_buf`. + dst_model_buf = self._create_model_buf() + populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf) + populator_dst.load_metadata_and_associated_files( + populator_src.get_model_buffer()) + populator_dst.populate() + + # Test if the metadata and associated files are populated correctly. + packed_files = populator_dst.get_packed_associated_file_list() + self.assertEqual(set(packed_files), set(self.expected_recorded_files)) + @parameterized.named_parameters( { "testcase_name": "InputTensorWithBert", From 29484702cef7908881262579382f4f4f8055170f Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 18 Jan 2023 08:00:48 -0800 Subject: [PATCH 330/346] Add `process_timestamp_bounds` into RectToRenderScaleCalculatorOptions. PiperOrigin-RevId: 502877541 --- mediapipe/calculators/util/rect_to_render_scale_calculator.cc | 4 +++- .../calculators/util/rect_to_render_scale_calculator.proto | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index 6ff6b3d51..85ed1db72 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -80,7 +80,9 @@ absl::Status RectToRenderScaleCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kRenderScaleTag).Set(); - + cc->SetProcessTimestampBounds( + cc->Options() + .process_timestamp_bounds()); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto index dda6e2c9c..377b12412 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto @@ -29,4 +29,8 @@ message RectToRenderScaleCalculatorOptions { // when actual object size on the image will be `B`, than all RenderData // primitives will be scaled with factor `B/A`. optional float multiplier = 1 [default = 0.01]; + + // When true, Process is called for every new timestamp bound, with or without + // new packets. 
+ optional bool process_timestamp_bounds = 2 [default = false]; } From 5687d19dec64dbea7ec70337ea67dd015d366d77 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 18 Jan 2023 09:06:48 -0800 Subject: [PATCH 331/346] Tensor: remove unused and unimplemented SetPreferredStorageType methods. PiperOrigin-RevId: 502893019 --- mediapipe/framework/formats/tensor.h | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 4a952ae09..fe0be31d1 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -370,13 +370,6 @@ class Tensor { bool ready_as_opengl_texture_2d() const { return valid_ & kValidOpenGlTexture2d; } - // Sets the type of underlying resource that is going to be allocated. - enum class StorageType { - kDefault, - kAhwb, - }; - static void SetPreferredStorageType(StorageType type); - static StorageType GetPreferredStorageType(); private: void Move(Tensor*); From e56fa8f258dbc32458e595ecca8043e7a8aeb893 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 18 Jan 2023 10:59:56 -0800 Subject: [PATCH 332/346] Source/SideSource -> Stream/SidePacket PiperOrigin-RevId: 502923931 --- mediapipe/framework/api2/builder_test.cc | 50 ++++++++++++------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index b01c2b759..08f4f0ca1 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -53,20 +53,20 @@ TEST(BuilderTest, BuildGraph) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, CopyableSource) { +TEST(BuilderTest, CopyableStream) { Graph graph; - Source a = graph.In("A").SetName("a").Cast(); - Source b = graph.In("B").SetName("b").Cast(); - SideSource side_a = + Stream a = graph.In("A").SetName("a").Cast(); + Stream b = graph.In("B").SetName("b").Cast(); + SidePacket side_a = graph.SideIn("SIDE_A").SetName("side_a").Cast(); - SideSource side_b = + SidePacket side_b = graph.SideIn("SIDE_B").SetName("side_b").Cast(); Destination out = graph.Out("OUT").Cast(); SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); - Source input = a; + Stream input = a; input = b; - SideSource side_input = side_b; + SidePacket side_input = side_b; side_input = side_a; input >> out; @@ -87,23 +87,23 @@ TEST(BuilderTest, CopyableSource) { TEST(BuilderTest, BuildGraphWithFunctions) { Graph graph; - Source base = graph.In("IN").SetName("base").Cast(); - SideSource side = graph.SideIn("SIDE").SetName("side").Cast(); + Stream base = graph.In("IN").SetName("base").Cast(); + SidePacket side = graph.SideIn("SIDE").SetName("side").Cast(); - auto foo_fn = [](Source base, SideSource side, Graph& graph) { + auto foo_fn = [](Stream base, SidePacket side, Graph& graph) { auto& foo = graph.AddNode("Foo"); base >> foo.In("BASE"); side >> foo.SideIn("SIDE"); return foo.Out("OUT")[0].Cast(); }; - Source foo_out = foo_fn(base, side, graph); + Stream foo_out = foo_fn(base, side, graph); - auto bar_fn = [](Source in, Graph& graph) { + auto bar_fn = [](Stream in, Graph& graph) { auto& bar = graph.AddNode("Bar"); in >> bar.In("IN"); return bar.Out("OUT")[0].Cast(); }; - Source bar_out = bar_fn(foo_out, graph); + Stream bar_out = bar_fn(foo_out, graph); bar_out.SetName("out") >> graph.Out("OUT"); @@ -375,26 +375,26 @@ class AnyAndSameTypeCalculator : public NodeIntf { TEST(BuilderTest, AnyAndSameTypeHandledProperly) { Graph graph; - 
Source any_input = graph.In("GRAPH_ANY_INPUT"); - Source int_input = graph.In("GRAPH_INT_INPUT").Cast(); + Stream any_input = graph.In("GRAPH_ANY_INPUT"); + Stream int_input = graph.In("GRAPH_INT_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - Source any_type_output = + Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; any_type_output.SetName("any_type_output"); - Source same_type_output = + Stream same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; same_type_output.SetName("same_type_output"); - Source recursive_same_type_output = + Stream recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; recursive_same_type_output.SetName("recursive_same_type_output"); - Source same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; + Stream same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; same_int_output.SetName("same_int_output"); - Source recursive_same_int_type_output = + Stream recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; recursive_same_int_type_output.SetName("recursive_same_int_type_output"); @@ -418,12 +418,12 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { TEST(BuilderTest, AnyTypeCanBeCast) { Graph graph; - Source any_input = + Stream any_input = graph.In("GRAPH_ANY_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; - Source any_type_output = + Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput] .SetName("any_type_output") .Cast(); @@ -462,7 +462,7 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) { TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { Graph graph; MultiSource any_multi_input = graph.In("ANY_INPUT"); - Source any_input = any_multi_input; + Stream any_input = any_multi_input; MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); Destination any_output = any_multi_output; any_input >> any_output; @@ -477,8 +477,8 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { Graph graph; - Source int_input = graph.In("INT_INPUT").Cast(); - Source any_input = graph.In("ANY_OUTPUT"); + Stream int_input = graph.In("INT_INPUT").Cast(); + Stream any_input = graph.In("ANY_OUTPUT"); Destination int_output = graph.Out("INT_OUTPUT").Cast(); Destination any_output = graph.Out("ANY_OUTPUT"); int_input >> int_output; From 66634bbef88c390ccd5c85774b839a81ea73240f Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Wed, 18 Jan 2023 16:36:09 -0800 Subject: [PATCH 333/346] Internal change PiperOrigin-RevId: 503011674 --- mediapipe/framework/tool/switch/BUILD | 34 +++++++ .../framework/tool/switch/packet_processor.h | 88 +++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 mediapipe/framework/tool/switch/BUILD create mode 100644 mediapipe/framework/tool/switch/packet_processor.h diff --git a/mediapipe/framework/tool/switch/BUILD b/mediapipe/framework/tool/switch/BUILD new file mode 100644 index 000000000..62f9095ef --- /dev/null +++ b/mediapipe/framework/tool/switch/BUILD @@ -0,0 +1,34 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +cc_library( + name = "packet_processor", + hdrs = ["packet_processor.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:packet", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], +) diff --git a/mediapipe/framework/tool/switch/packet_processor.h b/mediapipe/framework/tool/switch/packet_processor.h new file mode 100644 index 000000000..1789a46c5 --- /dev/null +++ b/mediapipe/framework/tool/switch/packet_processor.h @@ -0,0 +1,88 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ + +#include + +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// PacketConsumer accepts several tagged streams of packets. +class PacketConsumer { + public: + virtual ~PacketConsumer() = default; + + // Accepts a tagged input packet. + virtual absl::Status AddPacket(CollectionItemId id, Packet packet) = 0; + + // Returns the id for each input tag. + virtual std::shared_ptr InputTags() = 0; +}; + +// PacketConsumer delivers several tagged streams of packets. +class PacketProducer { + public: + virtual ~PacketProducer() = default; + + // Connects a consumer to recieve packets from this producer. + virtual void SetConsumer(PacketConsumer* consumer) = 0; +}; + +// SidePacketConsumer accepts several tagged constant packets. +class SidePacketConsumer { + public: + virtual ~SidePacketConsumer() = default; + + // Accepts a tagged input side-packet. + virtual absl::Status SetSidePacket(CollectionItemId id, Packet packet) = 0; + + // Returns the id for each input side-packet tag. + virtual std::shared_ptr SideInputTags() = 0; +}; + +// SidePacketProducer deleivers several tagged constant packets. +class SidePacketProducer { + public: + virtual ~SidePacketProducer() = default; + + // Connects a consumer to recieve packets from this producer. + virtual void SetSideConsumer(SidePacketConsumer* consumer) = 0; +}; + +// PacketProcessor consumes and produces packet streams and constant packets. 
+class PacketProcessor : public PacketConsumer, + public PacketProducer, + public SidePacketConsumer, + public SidePacketProducer { + public: + virtual ~PacketProcessor() = default; + + // Activate this PacketProcessor. + virtual absl::Status Start() = 0; + + // Block until this PacketProcessor has no remaining work to do. + virtual absl::Status WaitUntilIdle() = 0; + + // Deactivate this PacketProcessor. + virtual absl::Status Shutdown() = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ From 97af47ebf55e910b5c2125cba2f878e396be1b14 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Wed, 18 Jan 2023 18:51:17 -0800 Subject: [PATCH 334/346] Internal change PiperOrigin-RevId: 503035081 --- mediapipe/framework/tool/switch/BUILD | 26 +++++ .../framework/tool/switch/graph_processor.cc | 110 ++++++++++++++++++ .../framework/tool/switch/graph_processor.h | 59 ++++++++++ .../framework/tool/switch/packet_processor.h | 2 +- 4 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 mediapipe/framework/tool/switch/graph_processor.cc create mode 100644 mediapipe/framework/tool/switch/graph_processor.h diff --git a/mediapipe/framework/tool/switch/BUILD b/mediapipe/framework/tool/switch/BUILD index 62f9095ef..e7a3ba741 100644 --- a/mediapipe/framework/tool/switch/BUILD +++ b/mediapipe/framework/tool/switch/BUILD @@ -32,3 +32,29 @@ cc_library( "//mediapipe/framework/port:status", ], ) + +cc_library( + name = "graph_processor", + srcs = ["graph_processor.cc"], + hdrs = ["graph_processor.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":packet_processor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:input_stream_shard", + "//mediapipe/framework:output_stream_shard", + "//mediapipe/framework:validated_graph_config", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) diff --git a/mediapipe/framework/tool/switch/graph_processor.cc b/mediapipe/framework/tool/switch/graph_processor.cc new file mode 100644 index 000000000..f35730761 --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.cc @@ -0,0 +1,110 @@ +#include "mediapipe/framework/tool/switch/graph_processor.h" + +#include "absl/synchronization/mutex.h" + +namespace mediapipe { + +// TODO: add support for input and output side packets. 
+absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) { + graph_config_ = graph_config; + + ASSIGN_OR_RETURN(graph_input_map_, + tool::TagMap::Create(graph_config_.input_stream())); + ASSIGN_OR_RETURN(graph_output_map_, + tool::TagMap::Create(graph_config_.output_stream())); + return absl::OkStatus(); +} + +absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) { + absl::MutexLock lock(&graph_mutex_); + const std::string& stream_name = graph_input_map_->Names().at(id.value()); + return graph_->AddPacketToInputStream(stream_name, packet); +} + +std::shared_ptr GraphProcessor::InputTags() { + return graph_input_map_; +} + +absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) { + MP_RETURN_IF_ERROR(WaitUntilInitialized()); + auto it = consumer_ids_.find(id); + if (it == consumer_ids_.end()) { + return absl::NotFoundError( + absl::StrCat("Consumer stream not found: ", id.value())); + } + return consumer_->AddPacket(it->second, packet); +} + +void GraphProcessor::SetConsumer(PacketConsumer* consumer) { + absl::MutexLock lock(&graph_mutex_); + consumer_ = consumer; + auto input_map = consumer_->InputTags(); + for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) { + auto tag_index = input_map->TagAndIndexFromId(id); + auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second); + consumer_ids_[stream_id] = id; + } +} + +absl::Status GraphProcessor::ObserveGraph() { + for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId(); + ++id) { + std::string stream_name = graph_output_map_->Names().at(id.value()); + MP_RETURN_IF_ERROR(graph_->ObserveOutputStream( + stream_name, + [this, id](const Packet& packet) { return SendPacket(id, packet); }, + true)); + } + return absl::OkStatus(); +} + +absl::Status GraphProcessor::WaitUntilInitialized() { + absl::MutexLock lock(&graph_mutex_); + auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) { + return graph_ != nullptr && consumer_ != nullptr; + }; + graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized), + absl::Seconds(4)); + RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out."; + return absl::OkStatus(); +} + +absl::Status GraphProcessor::Start() { + absl::MutexLock lock(&graph_mutex_); + graph_ = std::make_unique(); + + // The graph is validated here with its specified inputs and output. 
+ MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_)); + MP_RETURN_IF_ERROR(ObserveGraph()); + MP_RETURN_IF_ERROR(graph_->StartRun({})); + return absl::OkStatus(); +} + +absl::Status GraphProcessor::Shutdown() { + absl::MutexLock lock(&graph_mutex_); + if (!graph_) { + return absl::OkStatus(); + } + MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources()); + MP_RETURN_IF_ERROR(graph_->WaitUntilDone()); + graph_ = nullptr; + return absl::OkStatus(); +} + +absl::Status GraphProcessor::WaitUntilIdle() { + absl::MutexLock lock(&graph_mutex_); + return graph_->WaitUntilIdle(); +} + +// TODO +absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) { + return absl::OkStatus(); +} +// TODO +std::shared_ptr GraphProcessor::SideInputTags() { + return nullptr; +} +// TODO +void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {} + +} // namespace mediapipe diff --git a/mediapipe/framework/tool/switch/graph_processor.h b/mediapipe/framework/tool/switch/graph_processor.h new file mode 100644 index 000000000..e2220b5dc --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.h @@ -0,0 +1,59 @@ +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/switch/packet_processor.h" + +namespace mediapipe { + +// Processes MediaPipe Packets using a MediaPipe CalculatorGraph. +class GraphProcessor : public PacketProcessor { + public: + GraphProcessor() = default; + + // Configures this GraphProcessor to create a run a CalculatorGraph. + absl::Status Initialize(CalculatorGraphConfig graph_config); + + public: + // The PacketProcessor interface. + absl::Status AddPacket(CollectionItemId id, Packet packet) override; + std::shared_ptr InputTags() override; + absl::Status SetSidePacket(CollectionItemId id, Packet packet) override; + std::shared_ptr SideInputTags() override; + void SetConsumer(PacketConsumer* consumer) override; + void SetSideConsumer(SidePacketConsumer* consumer) override; + absl::Status Start() override; + absl::Status Shutdown() override; + absl::Status WaitUntilIdle() override; + + private: + // Sends a tagged output packet. + absl::Status SendPacket(CollectionItemId id, Packet packet); + + // Observes output packets from the calculator graph. + absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_); + + // Blocks until this GraphProcessor is initialized. + absl::Status WaitUntilInitialized(); + + private: + CalculatorGraphConfig graph_config_; + std::shared_ptr graph_input_map_; + std::shared_ptr graph_output_map_; + std::map consumer_ids_; + + PacketConsumer* consumer_ = nullptr; + std::map side_packets_; + std::unique_ptr graph_ ABSL_GUARDED_BY(graph_mutex_) = + nullptr; + absl::Mutex graph_mutex_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ diff --git a/mediapipe/framework/tool/switch/packet_processor.h b/mediapipe/framework/tool/switch/packet_processor.h index 1789a46c5..d97883c53 100644 --- a/mediapipe/framework/tool/switch/packet_processor.h +++ b/mediapipe/framework/tool/switch/packet_processor.h @@ -56,7 +56,7 @@ class SidePacketConsumer { virtual std::shared_ptr SideInputTags() = 0; }; -// SidePacketProducer deleivers several tagged constant packets. 
+// SidePacketProducer delivers several tagged constant packets. class SidePacketProducer { public: virtual ~SidePacketProducer() = default; From e2dedcbfe569d4a33ad24ac77fee51a2ed53d5b2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 18 Jan 2023 19:40:19 -0800 Subject: [PATCH 335/346] Add SQRT_HANN window type to both SpectrogramCalculator and InverseSpectrogramCalculator. PiperOrigin-RevId: 503041493 --- mediapipe/calculators/audio/spectrogram_calculator.cc | 7 +++++++ mediapipe/calculators/audio/spectrogram_calculator.proto | 1 + 2 files changed, 8 insertions(+) diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index c038c0cd7..bd4d8f3bf 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -280,6 +280,13 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, &window); break; + case SpectrogramCalculatorOptions::SQRT_HANN: { + audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, + &window); + absl::c_transform(window, window.begin(), + [](double x) { return std::sqrt(x); }); + break; + } } // Propagate settings down to the actual Spectrogram object. diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto index 8e1e18051..ddfca1d1c 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.proto +++ b/mediapipe/calculators/audio/spectrogram_calculator.proto @@ -68,6 +68,7 @@ message SpectrogramCalculatorOptions { HANN = 0; HAMMING = 1; COSINE = 2; + SQRT_HANN = 4; } optional WindowType window_type = 6 [default = HANN]; From 7a7cc77a8154c6ac873763d39e54b14ae4de403a Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Thu, 19 Jan 2023 07:17:40 -0800 Subject: [PATCH 336/346] Internal change PiperOrigin-RevId: 503157344 --- .../unpack_media_sequence_calculator_test.cc | 2 +- .../framework/calculator_context_test.cc | 4 ++-- mediapipe/framework/port/proto_ns.h | 5 +++-- .../framework/profiler/graph_profiler_test.cc | 18 ++++++++------- .../framework/tool/options_lib_template.cc | 2 +- mediapipe/framework/tool/options_registry.cc | 22 ++++++++++--------- mediapipe/framework/tool/options_registry.h | 2 +- 7 files changed, 30 insertions(+), 25 deletions(-) diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index d8562ffc4..fbf775403 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -647,7 +647,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) { TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { // TODO: Suport proto3 proto.Any in CalculatorOptions. - // TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS". + // TODO: Avoid google::protobuf extensions in "RESAMPLER_OPTIONS". 
CalculatorOptions options; options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) ->set_padding_before_label(1); diff --git a/mediapipe/framework/calculator_context_test.cc b/mediapipe/framework/calculator_context_test.cc index e7612501a..be9103b4d 100644 --- a/mediapipe/framework/calculator_context_test.cc +++ b/mediapipe/framework/calculator_context_test.cc @@ -131,10 +131,10 @@ TEST(CalculatorTest, GetOptions) { auto calculator_state_3 = MakeCalculatorState(config.node(3), 3); auto cc_3 = MakeCalculatorContext(&*calculator_state_3); - // Get a proto2 options extension from Node::options. + // Get a google::protobuf options extension from Node::options. EXPECT_EQ(cc_0->Options().jitter(), 0.123); - // Get a proto2 options extension from Node::node_options. + // Get a google::protobuf options extension from Node::node_options. EXPECT_EQ(cc_1->Options().jitter(), 0.123); // Get a proto3 options protobuf::Any from Node::node_options. diff --git a/mediapipe/framework/port/proto_ns.h b/mediapipe/framework/port/proto_ns.h index 83aecdf49..53b854ff7 100644 --- a/mediapipe/framework/port/proto_ns.h +++ b/mediapipe/framework/port/proto_ns.h @@ -17,8 +17,9 @@ #include -// Temporary forward declarations for proto2 support on portable targets. -// Use proto_ns inside namespace mediapipe instead of proto2 namespace. +// Temporary forward declarations for google::protobuf support on portable +// targets. Use proto_ns inside namespace mediapipe instead of google::protobuf +// namespace. #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/repeated_field.h" diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 75d1c7ebd..e9badaa25 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -39,13 +39,15 @@ constexpr char kDummyTestCalculatorName[] = "DummyTestCalculator"; CalculatorGraphConfig::Node CreateNodeConfig( const std::string& raw_node_config) { CalculatorGraphConfig::Node node_config; - QCHECK(proto2::TextFormat::ParseFromString(raw_node_config, &node_config)); + QCHECK(google::protobuf::TextFormat::ParseFromString(raw_node_config, + &node_config)); return node_config; } CalculatorGraphConfig CreateGraphConfig(const std::string& raw_graph_config) { CalculatorGraphConfig graph_config; - QCHECK(proto2::TextFormat::ParseFromString(raw_graph_config, &graph_config)); + QCHECK(google::protobuf::TextFormat::ParseFromString(raw_graph_config, + &graph_config)); return graph_config; } @@ -1167,7 +1169,7 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) { TEST(GraphProfilerTest, ParallelReads) { // A graph that processes a certain number of packets before finishing. CalculatorGraphConfig config; - QCHECK(proto2::TextFormat::ParseFromString(R"( + QCHECK(google::protobuf::TextFormat::ParseFromString(R"( profiler_config { enable_profiler: true } @@ -1189,7 +1191,7 @@ TEST(GraphProfilerTest, ParallelReads) { } output_stream: "OUT:0:the_integers" )", - &config)); + &config)); // Start running the graph on its own threads. 
 absl::Mutex out_1_mutex;
@@ -1246,7 +1248,7 @@ std::set<std::string> GetCalculatorNames(const CalculatorGraphConfig& config) {
 
 TEST(GraphProfilerTest, CalculatorProfileFilter) {
   CalculatorGraphConfig config;
-  QCHECK(proto2::TextFormat::ParseFromString(R"(
+  QCHECK(google::protobuf::TextFormat::ParseFromString(R"(
     profiler_config {
       enable_profiler: true
     }
@@ -1268,7 +1270,7 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) {
     }
     output_stream: "OUT:0:the_integers"
   )",
-                                             &config));
+                                                     &config));
 
  std::set<std::string> expected_names;
  expected_names = {"RangeCalculator", "PassThroughCalculator"};
@@ -1295,7 +1297,7 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) {
 
 TEST(GraphProfilerTest, CaptureProfilePopulateConfig) {
   CalculatorGraphConfig config;
-  QCHECK(proto2::TextFormat::ParseFromString(R"(
+  QCHECK(google::protobuf::TextFormat::ParseFromString(R"(
     profiler_config {
       enable_profiler: true
       trace_enabled: true
@@ -1310,7 +1312,7 @@ TEST(GraphProfilerTest, CaptureProfilePopulateConfig) {
     input_stream: "input_stream"
   }
 )",
-                                             &config));
+                                                 &config));
   CalculatorGraph graph;
   MP_ASSERT_OK(graph.Initialize(config));
   GraphProfile profile;
diff --git a/mediapipe/framework/tool/options_lib_template.cc b/mediapipe/framework/tool/options_lib_template.cc
index 21a5db10f..4861132a2 100644
--- a/mediapipe/framework/tool/options_lib_template.cc
+++ b/mediapipe/framework/tool/options_lib_template.cc
@@ -28,7 +28,7 @@ constexpr char kDescriptorContents[] =
 mediapipe::FieldData ReadFileDescriptorSet(const std::string& pb) {
   mediapipe::FieldData result;
   *result.mutable_message_value()->mutable_type_url() =
-      "proto2.FileDescriptorSet";
+      "google::protobuf.FileDescriptorSet";
   *result.mutable_message_value()->mutable_value() = pb;
 
   // Force linking of the generated options protobuf.
diff --git a/mediapipe/framework/tool/options_registry.cc b/mediapipe/framework/tool/options_registry.cc
index f6858be0a..07cc65a95 100644
--- a/mediapipe/framework/tool/options_registry.cc
+++ b/mediapipe/framework/tool/options_registry.cc
@@ -66,26 +66,28 @@ std::string GetFieldString(const FieldData& message_data,
 
 void RegisterDescriptorProtos(
     absl::flat_hash_map<std::string, Descriptor>& result) {
  std::vector<Descriptor> descriptors = {
-      {"proto2.FileDescriptorSet",
+      {"google::protobuf.FileDescriptorSet",
        {
-           {"file", 1, FieldType::TYPE_MESSAGE, "proto2.FileDescriptorProto"},
+           {"file", 1, FieldType::TYPE_MESSAGE,
+            "google::protobuf.FileDescriptorProto"},
        }},
-      {"proto2.FileDescriptorProto",
+      {"google::protobuf.FileDescriptorProto",
        {
            {"package", 2, FieldType::TYPE_STRING, ""},
            {"message_type", 4, FieldType::TYPE_MESSAGE,
-            "proto2.DescriptorProto"},
+            "google::protobuf.DescriptorProto"},
        }},
-      {"proto2.DescriptorProto",
+      {"google::protobuf.DescriptorProto",
        {
            {"name", 1, FieldType::TYPE_STRING, ""},
-           {"field", 2, FieldType::TYPE_MESSAGE, "proto2.FieldDescriptorProto"},
+           {"field", 2, FieldType::TYPE_MESSAGE,
+            "google::protobuf.FieldDescriptorProto"},
            {"extension", 6, FieldType::TYPE_MESSAGE,
-            "proto2.FieldDescriptorProto"},
+            "google::protobuf.FieldDescriptorProto"},
            {"nested_type", 3, FieldType::TYPE_MESSAGE,
-            "proto2.DescriptorProto"},
+            "google::protobuf.DescriptorProto"},
        }},
-      {"proto2.FieldDescriptorProto",
+      {"google::protobuf.FieldDescriptorProto",
        {
            {"name", 1, FieldType::TYPE_STRING, ""},
            {"number", 3, FieldType::TYPE_INT32, ""},
@@ -140,7 +142,7 @@ void OptionsRegistry::Register(const FieldData& message_type,
 
 const Descriptor* OptionsRegistry::GetProtobufDescriptor(
     const std::string& type_name) {
-  if (descriptors().count("proto2.DescriptorProto") == 0) {
+  if 
(descriptors().count("google::protobuf.DescriptorProto") == 0) {
     RegisterDescriptorProtos(descriptors());
   }
   absl::ReaderMutexLock lock(&mutex());
diff --git a/mediapipe/framework/tool/options_registry.h b/mediapipe/framework/tool/options_registry.h
index b843b113a..3b2d2be89 100644
--- a/mediapipe/framework/tool/options_registry.h
+++ b/mediapipe/framework/tool/options_registry.h
@@ -28,7 +28,7 @@ class OptionsRegistry {
   // Finds the descriptor for a protobuf.
   static const Descriptor* GetProtobufDescriptor(const std::string& type_name);
 
-  // Returns all known proto2 extensions to a type.
+  // Returns all known google::protobuf extensions to a type.
   static void FindAllExtensions(absl::string_view extendee,
                                 std::vector<const FieldDescriptor*>* result);
 
From dcd2adad532f3f65703a7c387f182090a1229c51 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 19 Jan 2023 09:17:39 -0800
Subject: [PATCH 337/346] Removing broken links. They might not be relevant
 since we only support TfLite models.

PiperOrigin-RevId: 503183358
---
 docs/solutions/models.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/docs/solutions/models.md b/docs/solutions/models.md
index 18bcf0c8b..325c41f1b 100644
--- a/docs/solutions/models.md
+++ b/docs/solutions/models.md
@@ -94,8 +94,6 @@ one over the other.
 
 *   [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
 *   [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
-*   [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model)
-*   [Model information](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)
 
 ### [Objectron](https://google.github.io/mediapipe/solutions/objectron)
 
From a02097ea083cc318d33edc236a3824a0d50002a8 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Thu, 19 Jan 2023 10:06:42 -0800
Subject: [PATCH 338/346] Fix comments

PiperOrigin-RevId: 503195768
---
 mediapipe/tasks/web/audio/index.ts  | 2 +-
 mediapipe/tasks/web/text/index.ts   | 2 +-
 mediapipe/tasks/web/vision/index.ts | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts
index 44fa7eb25..e7465878b 100644
--- a/mediapipe/tasks/web/audio/index.ts
+++ b/mediapipe/tasks/web/audio/index.ts
@@ -18,7 +18,7 @@ import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/a
 import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder';
 import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
 
-// Declare the variables locally so that Rollup in OSS includes them explcilty
+// Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
const AudioClassifier = AudioClassifierImpl; const AudioEmbedder = AudioEmbedderImpl; diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index 2c9e6fead..cfa990e58 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -18,7 +18,7 @@ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fi import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; -// Declare the variables locally so that Rollup in OSS includes them explcilty +// Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; const TextClassifier = TextClassifierImpl; diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index e13f8183f..49f23c243 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -21,7 +21,7 @@ import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/ import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; -// Declare the variables locally so that Rollup in OSS includes them explcilty +// Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; const GestureRecognizer = GestureRecognizerImpl; From db1a89324e6ffc100bf7723fbfaf2673b4f36ecc Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 19 Jan 2023 10:39:05 -0800 Subject: [PATCH 339/346] Add mediapipe::Image output to the graph runner PiperOrigin-RevId: 503204918 --- .../graph_runner/graph_runner_image_lib.ts | 87 ++++++++++++++++--- 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index 7a4ea09e2..9608ebcc7 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -1,4 +1,6 @@ -import {ImageSource, GraphRunner} from './graph_runner'; +import {GraphRunner, ImageSource} from './graph_runner'; + + /** * We extend from a GraphRunner constructor. This ensures our mixin has @@ -8,6 +10,12 @@ import {ImageSource, GraphRunner} from './graph_runner'; // tslint:disable-next-line:no-any type LibConstructor = new (...args: any[]) => GraphRunner; +/** An image returned from a MediaPipe graph. */ +export interface WasmImage { + data: Uint8Array|Float32Array; + width: number; + height: number; +} /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -16,26 +24,33 @@ export declare interface WasmImageModule { _addBoundTextureAsImageToStream: (streamNamePtr: number, width: number, height: number, timestamp: number) => void; + _attachImageListener: (streamNamePtr: number) => void; + _attachImageVectorListener: (streamNamePtr: number) => void; } /** * An implementation of GraphRunner that supports binding GPU image data as - * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for - * effective multiple inheritance. Example usage: - * `const GraphRunnerImageLib = SupportImage(GraphRunner);` + * `mediapipe::Image` instances. 
We implement as a proper TS mixin, to allow
+ * for effective multiple inheritance. Example usage: `const GraphRunnerImageLib
+ * = SupportImage(GraphRunner);`
  */
 // tslint:disable-next-line:enforce-name-casing
 export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
   return class extends Base {
+    get wasmImageModule(): WasmImageModule {
+      return this.wasmModule as unknown as WasmImageModule;
+    }
+
     /**
-     * Takes the relevant information from the HTML video or image element, and
-     * passes it into the WebGL-based graph for processing on the given stream
-     * at the given timestamp as a MediaPipe image. Processing will not occur
-     * until a blocking call (like processVideoGl or finishProcessing) is made.
+     * Takes the relevant information from the HTML video or image element,
+     * and passes it into the WebGL-based graph for processing on the given
+     * stream at the given timestamp as a MediaPipe image. Processing will not
+     * occur until a blocking call (like processVideoGl or finishProcessing)
+     * is made.
      * @param imageSource Reference to the video frame we wish to add into our
      *     graph.
-     * @param streamName The name of the MediaPipe graph stream to add the frame
-     *     to.
+     * @param streamName The name of the MediaPipe graph stream to add the
+     *     frame to.
      * @param timestamp The timestamp of the input frame, in ms.
      */
     addGpuBufferAsImageToStream(
@@ -43,9 +58,55 @@ export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
       this.wrapStringPtr(streamName, (streamNamePtr: number) => {
         const [width, height] =
             this.bindTextureToStream(imageSource, streamNamePtr);
-        (this.wasmModule as unknown as WasmImageModule)
-            ._addBoundTextureAsImageToStream(
-                streamNamePtr, width, height, timestamp);
+        this.wasmImageModule._addBoundTextureAsImageToStream(
+            streamNamePtr, width, height, timestamp);
       });
     }
+
+    /**
+     * Attaches a mediapipe::Image packet listener to the specified output
+     * stream.
+     * @param outputStreamName The name of the graph output stream to grab
+     *     mediapipe::Image data from.
+     * @param callbackFcn The function that will be called back with the data,
+     *     as it is received.  Note that the data is only guaranteed to exist
+     *     for the duration of the callback, and the callback will be called
+     *     inline, so it should not perform overly complicated (or any async)
+     *     behavior.
+     */
+    attachImageListener(
+        outputStreamName: string,
+        callbackFcn: (data: WasmImage, timestamp: number) => void): void {
+      // Set up our TS listener to receive any packets for this stream.
+      this.setListener(outputStreamName, callbackFcn);
+
+      // Tell our graph to listen for mediapipe::Image packets on this stream.
+      this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
+        this.wasmImageModule._attachImageListener(outputStreamNamePtr);
+      });
+    }
+
+    /**
+     * Attaches a mediapipe::Image[] packet listener to the specified
+     * output_stream.
+     * @param outputStreamName The name of the graph output stream to grab
+     *     std::vector<mediapipe::Image> data from.
+     * @param callbackFcn The function that will be called back with the data,
+     *     as it is received.  Note that the data is only guaranteed to exist
+     *     for the duration of the callback, and the callback will be called
+     *     inline, so it should not perform overly complicated (or any async)
+     *     behavior.
+     */
+    attachImageVectorListener(
+        outputStreamName: string,
+        callbackFcn: (data: WasmImage[], timestamp: number) => void): void {
+      // Set up our TS listener to receive any packets for this stream.
+      this.setVectorListener(outputStreamName, callbackFcn);
+
+      // Tell our graph to listen for std::vector<mediapipe::Image> packets
+      // on this stream.
+      this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
+        this.wasmImageModule._attachImageVectorListener(outputStreamNamePtr);
       });
     }
   };
 
From 921b6a6befae381ba873fb61ba170d902a1c6b02 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 19 Jan 2023 22:09:55 -0800
Subject: [PATCH 340/346] This CL fixes the typo from _PALM_LANMARKS to
 _PALM_LANDMARKS.

PiperOrigin-RevId: 503352055
---
 mediapipe/python/solutions/drawing_styles.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mediapipe/python/solutions/drawing_styles.py b/mediapipe/python/solutions/drawing_styles.py
index b43bca8d3..5d75d5b30 100644
--- a/mediapipe/python/solutions/drawing_styles.py
+++ b/mediapipe/python/solutions/drawing_styles.py
@@ -37,9 +37,10 @@ _THICKNESS_FINGER = 2
 _THICKNESS_DOT = -1
 
 # Hand landmarks
-_PALM_LANMARKS = (HandLandmark.WRIST, HandLandmark.THUMB_CMC,
-                  HandLandmark.INDEX_FINGER_MCP, HandLandmark.MIDDLE_FINGER_MCP,
-                  HandLandmark.RING_FINGER_MCP, HandLandmark.PINKY_MCP)
+_PALM_LANDMARKS = (HandLandmark.WRIST, HandLandmark.THUMB_CMC,
+                   HandLandmark.INDEX_FINGER_MCP,
+                   HandLandmark.MIDDLE_FINGER_MCP, HandLandmark.RING_FINGER_MCP,
+                   HandLandmark.PINKY_MCP)
 _THUMP_LANDMARKS = (HandLandmark.THUMB_MCP, HandLandmark.THUMB_IP,
                     HandLandmark.THUMB_TIP)
 _INDEX_FINGER_LANDMARKS = (HandLandmark.INDEX_FINGER_PIP,
@@ -54,7 +55,7 @@ _RING_FINGER_LANDMARKS = (HandLandmark.RING_FINGER_PIP,
 _PINKY_FINGER_LANDMARKS = (HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP,
                            HandLandmark.PINKY_TIP)
 _HAND_LANDMARK_STYLE = {
-    _PALM_LANMARKS:
+    _PALM_LANDMARKS:
         DrawingSpec(
             color=_RED, thickness=_THICKNESS_DOT, circle_radius=_RADIUS),
     _THUMP_LANDMARKS:
 
From 1124569c29edad16e86a77e57407ca7abf0dc4a2 Mon Sep 17 00:00:00 2001
From: Nikolay Chirkov
Date: Mon, 23 Jan 2023 10:58:14 -0800
Subject: [PATCH 341/346] Tensor: Make tensor not require the "-x
 objective-c++" option.

Previously tensor.h was compiled differently for C++ and Objective-C++, which
violates the ODR (one-definition rule). Tensor has no conditionally compiled
virtual methods, but it did have some Metal-related data members. Instead, the
tensor class now unconditionally holds a unique_ptr to a forward-declared
MtlResources struct. MtlResources is defined only in the cc-file, which is
compiled just once per project, so no ODR violation occurs.
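In outline, the pattern looks like this (a minimal sketch; the TensorLike
name is illustrative and not taken from the patch):

    // tensor_like.h -- compiles identically for C++ and Objective-C++,
    // because no Metal type appears in the header.
    #include <memory>

    struct MtlResources;  // forward declaration; layout stays hidden

    class TensorLike {
     public:
      TensorLike();
      ~TensorLike();  // defined in the .cc, where MtlResources is complete
     private:
      std::unique_ptr<MtlResources> mtl_resources_;
    };

    // tensor_like.cc -- compiled exactly once per project, so the single
    // definition of MtlResources (with or without Metal members) cannot
    // violate the one-definition rule.
    struct MtlResources { /* Metal handles when MEDIAPIPE_METAL_ENABLED */ };

    TensorLike::TensorLike()
        : mtl_resources_(std::make_unique<MtlResources>()) {}
    TensorLike::~TensorLike() = default;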
PiperOrigin-RevId: 504029286 --- mediapipe/calculators/tensor/BUILD | 81 +----------- .../tensor/image_to_tensor_converter_metal.cc | 6 +- .../tensor/inference_calculator_metal.cc | 18 ++- .../tensor/tensor_converter_calculator.cc | 3 +- .../tensors_to_detections_calculator.cc | 23 ++-- .../tensors_to_segmentation_calculator.cc | 4 +- mediapipe/framework/formats/BUILD | 5 +- mediapipe/framework/formats/tensor.cc | 125 ++++++++++-------- mediapipe/framework/formats/tensor.h | 46 +------ .../formats/tensor_mtl_buffer_view.h | 61 +++++++++ .../tasks/cc/components/calculators/BUILD | 8 -- .../tasks/cc/components/processors/BUILD | 16 --- 12 files changed, 177 insertions(+), 219 deletions(-) create mode 100644 mediapipe/framework/formats/tensor_mtl_buffer_view.h diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 127280107..69d666092 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -53,14 +53,6 @@ mediapipe_proto_library( cc_library( name = "audio_to_tensor_calculator", srcs = ["audio_to_tensor_calculator.cc"], - copts = select({ - # b/215212850 - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", - ], - "//conditions:default": [], - }), deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -161,14 +153,6 @@ mediapipe_proto_library( cc_library( name = "feedback_tensors_calculator", srcs = ["feedback_tensors_calculator.cc"], - copts = select({ - # b/215212850 - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", - ], - "//conditions:default": [], - }), deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -207,14 +191,6 @@ mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -267,14 +243,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -316,14 +284,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -414,14 +374,6 @@ cc_library( name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], hdrs = ["inference_calculator.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -495,6 +447,7 @@ cc_library( tags = ["ios"], deps = [ "inference_calculator_interface", + 
"//mediapipe/framework/formats:tensor", "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:gpu_buffer", @@ -513,14 +466,6 @@ cc_library( cc_library( name = "inference_runner", hdrs = ["inference_runner.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -532,14 +477,6 @@ cc_library( name = "inference_interpreter_delegate_runner", srcs = ["inference_interpreter_delegate_runner.cc"], hdrs = ["inference_interpreter_delegate_runner.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -561,14 +498,6 @@ cc_library( srcs = [ "inference_calculator_cpu.cc", ], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -607,14 +536,6 @@ cc_library( srcs = [ "inference_calculator_xnnpack.cc", ], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":inference_calculator_interface", ":inference_calculator_utils", diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index a8211d39b..354547042 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -36,6 +36,10 @@ #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#if MEDIAPIPE_METAL_ENABLED +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" +#endif // MEDIAPIPE_METAL_ENABLED + namespace mediapipe { namespace { @@ -376,7 +380,7 @@ class MetalProcessor : public ImageToTensorConverter { id command_buffer = [metal_helper_ commandBuffer]; const auto& buffer_view = - output_tensor.GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensor, command_buffer); MP_RETURN_IF_ERROR(extractor_->Execute( texture, roi, /*flip_horizontaly=*/false, transform.scale, transform.offset, diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index 750f0456e..fba18a81c 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -24,6 +24,8 @@ #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #include "mediapipe/gpu/gpu_buffer.h" @@ -150,11 +152,12 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { command_buffer.label = @"InferenceCalculator"; // 
Explicit copy input with conversion float 32 bits to 16 bits.
   for (int i = 0; i < input_tensors.size(); ++i) {
-    auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer);
+    auto input_view =
+        MtlBufferView::GetReadView(input_tensors[i], command_buffer);
     // Reshape tensor.
     tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape());
     auto gpu_buffer_view =
-        gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer);
+        MtlBufferView::GetWriteView(*gpu_buffers_in_[i], command_buffer);
     id<MTLComputeCommandEncoder> input_encoder =
         [command_buffer computeCommandEncoder];
     [converter_to_BPHWC4_ convertWithEncoder:input_encoder
@@ -174,9 +177,10 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
                      output_shapes_[i]);
     // Reshape tensor.
     tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]);
-    auto read_view = gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer);
+    auto read_view =
+        MtlBufferView::GetReadView(*gpu_buffers_out_[i], command_buffer);
     auto write_view =
-        output_tensors->at(i).GetMtlBufferWriteView(command_buffer);
+        MtlBufferView::GetWriteView(output_tensors->at(i), command_buffer);
     id<MTLComputeCommandEncoder> output_encoder =
         [command_buffer computeCommandEncoder];
     [converter_from_BPHWC4_ convertWithEncoder:output_encoder
@@ -258,7 +262,7 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
                          : Tensor::ElementType::kFloat32,
                      Tensor::Shape{dims}));
     auto buffer_view =
-        gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice);
+        MtlBufferView::GetWriteView(*gpu_buffers_in_[i], gpu_helper_.mtlDevice);
     RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
                      delegate_.get(), input_indices[i], buffer_view.buffer()),
                  true);
@@ -286,8 +290,8 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
                      Tensor::Shape{dims}));
     RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
                      delegate_.get(), output_indices[i],
-                     gpu_buffers_out_[i]
-                         ->GetMtlBufferWriteView(gpu_helper_.mtlDevice)
+                     MtlBufferView::GetWriteView(*gpu_buffers_out_[i],
+                                                 gpu_helper_.mtlDevice)
                          .buffer()),
                  true);
   }
diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc
index 0b750b859..4b05488fd 100644
--- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc
+++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc
@@ -31,6 +31,7 @@
 
 #import <CoreVideo/CoreVideo.h>
 #import <Metal/Metal.h>
+#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
 #import "mediapipe/gpu/MPPMetalHelper.h"
 #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 #include "mediapipe/gpu/gl_calculator_helper.h"
@@ -304,7 +305,7 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
     id<MTLTexture> src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input];
     [compute_encoder setTexture:src_texture atIndex:0];
     auto output_view =
-        output_tensors->at(0).GetMtlBufferWriteView(command_buffer);
+        MtlBufferView::GetWriteView(output_tensors->at(0), command_buffer);
     [compute_encoder setBuffer:output_view.buffer() offset:0 atIndex:1];
     MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1);
     MTLSize threadgroups =
diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc
index 97ef01b4c..4bb3f0f57 100644
--- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc
@@ -41,6 +41,7 @@
 
 #import <CoreVideo/CoreVideo.h>
 #import <Metal/Metal.h>
+#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
 #import 
"mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #endif // MEDIAPIPE_METAL_ENABLED @@ -536,10 +537,11 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( if (input_tensors.size() == kNumInputTensorsWithAnchors) { RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); auto command_buffer = [gpu_helper_ commandBuffer]; - auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto src_buffer = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.anchors_tensor_index()], + command_buffer); auto dest_buffer = - raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*raw_anchors_buffer_, command_buffer); id blit_command = [command_buffer blitCommandEncoder]; [blit_command copyFromBuffer:src_buffer.buffer() @@ -571,15 +573,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:decode_program_]; { auto scored_boxes_view = - scored_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*scored_boxes_buffer_, command_buffer); auto decoded_boxes_view = - decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*decoded_boxes_buffer_, command_buffer); [command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0]; - auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto input0_view = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.detections_tensor_index()], + command_buffer); [command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1]; auto raw_anchors_view = - raw_anchors_buffer_->GetMtlBufferReadView(command_buffer); + MtlBufferView::GetReadView(*raw_anchors_buffer_, command_buffer); [command_encoder setBuffer:raw_anchors_view.buffer() offset:0 atIndex:2]; MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1); MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1); @@ -588,8 +591,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:score_program_]; [command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0]; - auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto input1_view = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.scores_tensor_index()], command_buffer); [command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1]; MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 172f70880..839451ab7 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -53,6 +53,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #endif // MEDIAPIPE_METAL_ENABLED @@ -485,7 +486,8 @@ absl::Status TensorsToSegmentationCalculator::ProcessGpu( [command_buffer computeCommandEncoder]; [command_encoder setComputePipelineState:mask_program_]; - auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + auto read_view = + MtlBufferView::GetReadView(input_tensors[0], 
command_buffer);
   [command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0];
 
   mediapipe::GpuBuffer small_mask_buffer = [metal_helper_
diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index 371f23ed1..10aa3fca0 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -431,7 +431,10 @@ cc_library(
     hdrs = [
         "tensor.h",
         "//mediapipe/framework/formats/tensor:internal.h",
-    ],
+    ] + select({
+        "//mediapipe:ios": ["tensor_mtl_buffer_view.h"],
+        "//conditions:default": [],
+    }),
     copts = select({
         "//mediapipe:apple": [
             "-x objective-c++",
diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc
index 3f11d368a..1dbd8f8ac 100644
--- a/mediapipe/framework/formats/tensor.cc
+++ b/mediapipe/framework/formats/tensor.cc
@@ -25,8 +25,11 @@
 #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 
 #if MEDIAPIPE_METAL_ENABLED
+#import <Metal/Metal.h>
 #include <mach/mach_init.h>
 #include <mach/vm_map.h>
+
+#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
 #else
 #include <cstdlib>
 #endif  // MEDIAPIPE_METAL_ENABLED
@@ -61,6 +64,12 @@ int BhwcDepthFromShape(const Tensor::Shape& shape) {
 // 3) pad/"unpad" the bitmap after transfer CPU <-> GPU
 
 #if MEDIAPIPE_METAL_ENABLED
+// No ODR violation here because this file compiled just once per project.
+struct MtlResources {
+  id<MTLCommandBuffer> command_buffer = nil;
+  id<MTLDevice> device = nil;
+  id<MTLBuffer> metal_buffer = nil;
+};
 namespace {
 // MTLBuffer can use existing properly aligned and allocated CPU memory.
 size_t AlignToPageSize(size_t size) {
@@ -83,52 +92,56 @@ void DeallocateVirtualMemory(void* pointer, size_t size) {
 }
 }  // namespace
 
-Tensor::MtlBufferView Tensor::GetMtlBufferReadView(
-    id<MTLCommandBuffer> command_buffer) const {
-  LOG_IF(FATAL, valid_ == kValidNone)
+void MtlBufferView::AllocateMtlBuffer(const Tensor& tensor,
+                                      id<MTLDevice> device) {
+  tensor.mtl_resources_->device = device;
+  if (!tensor.cpu_buffer_) {
+    // It also means that the metal buffer is not allocated yet.
+    tensor.cpu_buffer_ = AllocateVirtualMemory(tensor.bytes());
+  }
+  if (!tensor.mtl_resources_->metal_buffer) {
+    tensor.mtl_resources_->metal_buffer = [tensor.mtl_resources_->device
+        newBufferWithBytesNoCopy:tensor.cpu_buffer_
+                          length:AlignToPageSize(tensor.bytes())
+                         options:MTLResourceStorageModeShared |
+                                 MTLResourceCPUCacheModeDefaultCache
+                     deallocator:^(void* pointer, NSUInteger length) {
+                       DeallocateVirtualMemory(pointer, length);
+                     }];
+  }
+}
+
+MtlBufferView MtlBufferView::GetReadView(const Tensor& tensor,
+                                         id<MTLCommandBuffer> command_buffer) {
+  LOG_IF(FATAL, tensor.valid_ == Tensor::kValidNone)
       << "Tensor must be written prior to read from.";
-  LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidMetalBuffer)))
+  LOG_IF(FATAL,
+         !(tensor.valid_ & (Tensor::kValidCpu | Tensor::kValidMetalBuffer)))
      << "Tensor conversion between different GPU resources is not supported "
         "yet.";
-  auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
-  valid_ |= kValidMetalBuffer;
-  AllocateMtlBuffer([command_buffer device]);
-  return {metal_buffer_, std::move(lock)};
+  auto lock(absl::make_unique<absl::MutexLock>(&tensor.view_mutex_));
+  tensor.valid_ |= Tensor::kValidMetalBuffer;
+  AllocateMtlBuffer(tensor, [command_buffer device]);
+  return {tensor.mtl_resources_->metal_buffer, std::move(lock)};
 }
 
-Tensor::MtlBufferView Tensor::GetMtlBufferWriteView(
-    id<MTLCommandBuffer> command_buffer) const {
+MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor,
+                                          id<MTLCommandBuffer> command_buffer) {
   // Don't overwrite command buffer at which the metal buffer has been written
   // so we can wait until completed.
-  command_buffer_ = command_buffer;
-  return GetMtlBufferWriteView([command_buffer device]);
+  tensor.mtl_resources_->command_buffer = command_buffer;
+  return GetWriteView(tensor, [command_buffer device]);
 }
 
-Tensor::MtlBufferView Tensor::GetMtlBufferWriteView(
-    id<MTLDevice> device) const {
-  auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
-  valid_ = kValidMetalBuffer;
-  AllocateMtlBuffer(device);
-  return {metal_buffer_, std::move(lock)};
-}
-
-void Tensor::AllocateMtlBuffer(id<MTLDevice> device) const {
-  device_ = device;
-  if (!cpu_buffer_) {
-    // It also means that the metal buffer is not allocated yet.
-    cpu_buffer_ = AllocateVirtualMemory(bytes());
-  }
-  if (!metal_buffer_) {
-    metal_buffer_ =
-        [device_ newBufferWithBytesNoCopy:cpu_buffer_
-                                   length:AlignToPageSize(bytes())
-                                  options:MTLResourceStorageModeShared |
-                                          MTLResourceCPUCacheModeDefaultCache
-                              deallocator:^(void* pointer, NSUInteger length) {
-                                DeallocateVirtualMemory(pointer, length);
-                              }];
-  }
+MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor,
+                                          id<MTLDevice> device) {
+  auto lock(absl::make_unique<absl::MutexLock>(&tensor.view_mutex_));
+  tensor.valid_ = Tensor::kValidMetalBuffer;
+  AllocateMtlBuffer(tensor, device);
+  return {tensor.mtl_resources_->metal_buffer, std::move(lock)};
 }
+#else
+struct MtlResources {};
 #endif  // MEDIAPIPE_METAL_ENABLED
 
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
@@ -379,6 +392,9 @@ Tensor& Tensor::operator=(Tensor&& src) {
   return *this;
 }
 
+Tensor::Tensor(Tensor&& src) { Move(&src); }
+Tensor::~Tensor() { Invalidate(); }
+
 void Tensor::Move(Tensor* src) {
   valid_ = src->valid_;
   src->valid_ = kValidNone;
@@ -388,15 +404,7 @@ void Tensor::Move(Tensor* src) {
   cpu_buffer_ = src->cpu_buffer_;
   src->cpu_buffer_ = nullptr;
   ahwb_tracking_key_ = src->ahwb_tracking_key_;
-#if MEDIAPIPE_METAL_ENABLED
-  device_ = src->device_;
-  src->device_ = nil;
-  command_buffer_ = src->command_buffer_;
-  src->command_buffer_ = nil;
-  metal_buffer_ = src->metal_buffer_;
-  src->metal_buffer_ = nil;
-#endif  // MEDIAPIPE_METAL_ENABLED
-
+  mtl_resources_ = std::move(src->mtl_resources_);
   MoveAhwbStuff(src);
 
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
@@ -415,12 +423,15 @@ void Tensor::Move(Tensor* src) {
 }
 
 Tensor::Tensor(ElementType element_type, const Shape& shape)
-    : element_type_(element_type), shape_(shape) {}
+    : element_type_(element_type),
+      shape_(shape),
+      mtl_resources_(std::make_unique<MtlResources>()) {}
 
 Tensor::Tensor(ElementType element_type, const Shape& shape,
                const QuantizationParameters& quantization_parameters)
     : element_type_(element_type),
       shape_(shape),
-      quantization_parameters_(quantization_parameters) {}
+      quantization_parameters_(quantization_parameters),
+      mtl_resources_(std::make_unique<MtlResources>()) {}
 
 #if MEDIAPIPE_METAL_ENABLED
 void Tensor::Invalidate() {
@@ -432,13 +443,16 @@ void Tensor::Invalidate() {
     absl::MutexLock lock(&view_mutex_);
     // If memory is allocated and not owned by the metal buffer.
     // TODO: Re-design cpu buffer memory management.
-    if (cpu_buffer_ && !metal_buffer_) {
+    if (cpu_buffer_ && !mtl_resources_->metal_buffer) {
       DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes()));
     }
-    metal_buffer_ = nil;
-    command_buffer_ = nil;
-    device_ = nil;
     cpu_buffer_ = nullptr;
+    // This becomes NULL if the tensor is moved.
+    if (mtl_resources_) {
+      mtl_resources_->metal_buffer = nil;
+      mtl_resources_->command_buffer = nil;
+      mtl_resources_->device = nil;
+    }
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
     // Don't need to wait for the resource to be deleted bacause if will be
     // released on last reference deletion inside the OpenGL driver.
@@ -532,10 +546,11 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
 // GPU-to-CPU synchronization and read-back.
 #if MEDIAPIPE_METAL_ENABLED
   if (valid_ & kValidMetalBuffer) {
-    LOG_IF(FATAL, !command_buffer_) << "Metal -> CPU synchronization "
-                                       "requires MTLCommandBuffer to be set.";
-    if (command_buffer_) {
-      [command_buffer_ waitUntilCompleted];
+    LOG_IF(FATAL, !mtl_resources_->command_buffer)
+        << "Metal -> CPU synchronization "
+           "requires MTLCommandBuffer to be set.";
+    if (mtl_resources_->command_buffer) {
+      [mtl_resources_->command_buffer waitUntilCompleted];
     }
   }
 #endif  // MEDIAPIPE_METAL_ENABLED
diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h
index fe0be31d1..1d670d805 100644
--- a/mediapipe/framework/formats/tensor.h
+++ b/mediapipe/framework/formats/tensor.h
@@ -29,9 +29,6 @@
 #include "mediapipe/framework/formats/tensor/internal.h"
 #include "mediapipe/framework/port.h"
 
-#if MEDIAPIPE_METAL_ENABLED
-#import <Metal/Metal.h>
-#endif  // MEDIAPIPE_METAL_ENABLED
 #ifndef MEDIAPIPE_NO_JNI
 #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
 #define MEDIAPIPE_TENSOR_USE_AHWB 1
@@ -66,7 +63,6 @@
 #endif
 
 namespace mediapipe {
-
 // Tensor is a container of multi-dimensional data that supports sharing the
 // content across different backends and APIs, currently: CPU / Metal / OpenGL.
 // Texture2DView is limited to 4 dimensions.
@@ -91,6 +87,7 @@ namespace mediapipe {
 //   float* pointer = view.buffer<float>();
 //   ...reading the cpu memory...
 
+struct MtlResources;
 class Tensor {
   class View {
    public:
@@ -144,9 +141,9 @@ class Tensor {
   Tensor(const Tensor&) = delete;
   Tensor& operator=(const Tensor&) = delete;
   // Move-only.
-  Tensor(Tensor&& src) { Move(&src); }
+  Tensor(Tensor&& src);
   Tensor& operator=(Tensor&&);
-  ~Tensor() { Invalidate(); }
+  ~Tensor();
 
   template <typename T>
   class CpuView : public View {
@@ -182,33 +179,6 @@ class Tensor {
       uint64_t source_location_hash =
          tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
 
-#if MEDIAPIPE_METAL_ENABLED
-  // TODO: id<MTLBuffer> vs. MtlBufferView.
-  class MtlBufferView : public View {
-   public:
-    id<MTLBuffer> buffer() const { return buffer_; }
-    MtlBufferView(MtlBufferView&& src)
-        : View(std::move(src)), buffer_(src.buffer_) {
-      src.buffer_ = nil;
-    }
-
-   protected:
-    friend class Tensor;
-    MtlBufferView(id<MTLBuffer> buffer, std::unique_ptr<absl::MutexLock>&& lock)
-        : View(std::move(lock)), buffer_(buffer) {}
-    id<MTLBuffer> buffer_;
-  };
-  // The command buffer status is checked for completeness if GPU-to-CPU
-  // synchronization is required.
-  // TODO: Design const and non-const view acquiring.
-  MtlBufferView GetMtlBufferReadView(id<MTLCommandBuffer> command_buffer) const;
-  MtlBufferView GetMtlBufferWriteView(
-      id<MTLCommandBuffer> command_buffer) const;
-  // Allocate new buffer.
-  // TODO: GPU-to-CPU design considerations.
-  MtlBufferView GetMtlBufferWriteView(id<MTLDevice> device) const;
-#endif  // MEDIAPIPE_METAL_ENABLED
-
 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
   using FinishingFunc = std::function<bool(bool)>;
   class AHardwareBufferView : public View {
@@ -372,6 +342,7 @@ class Tensor {
   }
 
  private:
+  friend class MtlBufferView;
   void Move(Tensor*);
   void Invalidate();
 
@@ -396,12 +367,9 @@ class Tensor {
   mutable void* cpu_buffer_ = nullptr;
   void AllocateCpuBuffer() const;
-#if MEDIAPIPE_METAL_ENABLED
-  mutable id<MTLCommandBuffer> command_buffer_ = nil;
-  mutable id<MTLDevice> device_ = nil;
-  mutable id<MTLBuffer> metal_buffer_ = nil;
-  void AllocateMtlBuffer(id<MTLDevice> device) const;
-#endif  // MEDIAPIPE_METAL_ENABLED
+  // Forward declaration of the MtlResources provides compile-time verification
+  // of ODR if this header includes any actual code that uses MtlResources.
+  mutable std::unique_ptr<MtlResources> mtl_resources_;
 
 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
   mutable AHardwareBuffer* ahwb_ = nullptr;
diff --git a/mediapipe/framework/formats/tensor_mtl_buffer_view.h b/mediapipe/framework/formats/tensor_mtl_buffer_view.h
new file mode 100644
index 000000000..a61659d3d
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_mtl_buffer_view.h
@@ -0,0 +1,61 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_
+#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_
+
+#import <Metal/Metal.h>
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port.h"
+
+namespace mediapipe {
+class MtlBufferView : public Tensor::View {
+ public:
+  // The command buffer status is checked for completeness if GPU-to-CPU
+  // synchronization is required.
+  static MtlBufferView GetReadView(const Tensor& tensor,
+                                   id<MTLCommandBuffer> command_buffer);
+  static MtlBufferView GetWriteView(const Tensor& tensor,
+                                    id<MTLCommandBuffer> command_buffer);
+  static MtlBufferView GetWriteView(const Tensor& tensor, id<MTLDevice> device);
+
+  id<MTLBuffer> buffer() const { return buffer_; }
+  MtlBufferView(MtlBufferView&& src)
+      : Tensor::View(std::move(src)), buffer_(src.buffer_) {
+    src.buffer_ = nil;
+  }
+
+ protected:
+  friend class Tensor;
+  static void AllocateMtlBuffer(const Tensor& tensor, id<MTLDevice> device);
+  MtlBufferView(id<MTLBuffer> buffer, std::unique_ptr<absl::MutexLock>&& lock)
+      : Tensor::View(std::move(lock)), buffer_(buffer) {}
+  id<MTLBuffer> buffer_;
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_
diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD
index bf31134e4..16931811c 100644
--- a/mediapipe/tasks/cc/components/calculators/BUILD
+++ b/mediapipe/tasks/cc/components/calculators/BUILD
@@ -79,14 +79,6 @@ mediapipe_proto_library(
 cc_library(
     name = "score_calibration_calculator",
     srcs = ["score_calibration_calculator.cc"],
-    copts = select({
-        # TODO: fix tensor.h not to require this, if possible
-        "//mediapipe:apple": [
-            "-x objective-c++",
-            "-fobjc-arc",  # enable reference-counting
-        ],
-        "//conditions:default": [],
-    }),
     deps = [
         ":score_calibration_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD
index 517a27114..cec44a9e3 100644
--- a/mediapipe/tasks/cc/components/processors/BUILD
+++ b/mediapipe/tasks/cc/components/processors/BUILD
@@ -28,14 +28,6 @@ cc_library(
     name = "classification_postprocessing_graph",
     srcs = ["classification_postprocessing_graph.cc"],
     hdrs = ["classification_postprocessing_graph.h"],
-    copts = select({
-        # TODO: fix tensor.h not to require this, if possible
-        "//mediapipe:apple": [
-            "-x objective-c++",
-            "-fobjc-arc",  # enable reference-counting
-        ],
-        "//conditions:default": [],
-    }),
     deps = [
         "//mediapipe/calculators/core:split_vector_calculator",
         "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
@@ -148,14 +140,6 @@ cc_library(
     name = "text_preprocessing_graph",
     srcs = ["text_preprocessing_graph.cc"],
     hdrs = ["text_preprocessing_graph.h"],
-    copts = select({
-        # TODO: fix tensor.h not to require this, if possible
-        "//mediapipe:apple": [
-            "-x objective-c++",
-            "-fobjc-arc",  # enable reference-counting
-        ],
-        "//conditions:default": [],
-    }),
     deps = [
         "//mediapipe/calculators/tensor:bert_preprocessor_calculator",
         "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
 
From 69d354fc89173035007280daef793b7a640542fe Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 23 Jan 2023 12:09:41 -0800
Subject: [PATCH 342/346] Use C++ struct as hand landmark detection results.
PiperOrigin-RevId: 504048095
---
 .../tasks/cc/components/containers/BUILD      |   9 ++
 .../containers/classification_result.cc       |  13 +++
 .../containers/classification_result.h        |   7 ++
 .../cc/components/containers/landmark.cc      |  65 +++++++++++
 .../tasks/cc/components/containers/landmark.h | 103 ++++++++++++++++++
 .../tasks/cc/vision/hand_landmarker/BUILD     |   3 +
 .../vision/hand_landmarker/hand_landmarker.cc |  60 ++++++----
 .../hand_landmarker/hand_landmarker_result.cc |  56 ++++++++++
 .../hand_landmarker/hand_landmarker_result.h  |  15 ++-
 .../hand_landmarker_result_test.cc            |  88 +++++++++++++++
 .../hand_landmarker/hand_landmarker_test.cc   |  62 ++++++++---
 11 files changed, 439 insertions(+), 42 deletions(-)
 create mode 100644 mediapipe/tasks/cc/components/containers/landmark.cc
 create mode 100644 mediapipe/tasks/cc/components/containers/landmark.h
 create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc
 create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc

diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD
index 0750a1482..a7307b2ce 100644
--- a/mediapipe/tasks/cc/components/containers/BUILD
+++ b/mediapipe/tasks/cc/components/containers/BUILD
@@ -62,3 +62,12 @@ cc_library(
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
     ],
 )
+
+cc_library(
+    name = "landmark",
+    srcs = ["landmark.cc"],
+    hdrs = ["landmark.h"],
+    deps = [
+        "//mediapipe/framework/formats:landmark_cc_proto",
+    ],
+)
diff --git a/mediapipe/tasks/cc/components/containers/classification_result.cc b/mediapipe/tasks/cc/components/containers/classification_result.cc
index 98583ff15..f2d88406d 100644
--- a/mediapipe/tasks/cc/components/containers/classification_result.cc
+++ b/mediapipe/tasks/cc/components/containers/classification_result.cc
@@ -40,6 +40,19 @@ Classifications ConvertToClassifications(const proto::Classifications& proto) {
   return classifications;
 }
 
+Classifications ConvertToClassifications(
+    const mediapipe::ClassificationList& proto, int head_index,
+    std::optional<std::string> head_name) {
+  Classifications classifications;
+  classifications.categories.reserve(proto.classification_size());
+  for (const auto& classification : proto.classification()) {
+    classifications.categories.push_back(ConvertToCategory(classification));
+  }
+  classifications.head_index = head_index;
+  classifications.head_name = head_name;
+  return classifications;
+}
+
 ClassificationResult ConvertToClassificationResult(
     const proto::ClassificationResult& proto) {
   ClassificationResult classification_result;
diff --git a/mediapipe/tasks/cc/components/containers/classification_result.h b/mediapipe/tasks/cc/components/containers/classification_result.h
index 88273fd00..e359fb33e 100644
--- a/mediapipe/tasks/cc/components/containers/classification_result.h
+++ b/mediapipe/tasks/cc/components/containers/classification_result.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include
 #include
 
+#include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/tasks/cc/components/containers/category.h"
 #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 
@@ -58,6 +59,12 @@ struct ClassificationResult {
 // Classifications struct.
 Classifications ConvertToClassifications(const proto::Classifications& proto);
 
+// Utility function to convert from ClassificationList proto to
+// Classifications struct.
+Classifications ConvertToClassifications(
+    const mediapipe::ClassificationList& proto, int head_index = 0,
+    std::optional<std::string> head_name = std::nullopt);
+
 // Utility function to convert from ClassificationResult proto to
 // ClassificationResult struct.
 ClassificationResult ConvertToClassificationResult(
diff --git a/mediapipe/tasks/cc/components/containers/landmark.cc b/mediapipe/tasks/cc/components/containers/landmark.cc
new file mode 100644
index 000000000..6d80cb835
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/landmark.cc
@@ -0,0 +1,65 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/containers/landmark.h"
+
+#include
+#include
+
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe::tasks::components::containers {
+
+Landmark ConvertToLandmark(const mediapipe::Landmark& proto) {
+  return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(),
+          /*visibility=*/proto.has_visibility()
+              ? std::optional<float>(proto.visibility())
+              : std::nullopt,
+          /*presence=*/proto.has_presence()
+              ? std::optional<float>(proto.presence())
+              : std::nullopt};
+}
+
+NormalizedLandmark ConvertToNormalizedLandmark(
+    const mediapipe::NormalizedLandmark& proto) {
+  return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(),
+          /*visibility=*/proto.has_visibility()
+              ? std::optional<float>(proto.visibility())
+              : std::nullopt,
+          /*presence=*/proto.has_presence()
+              ? std::optional<float>(proto.presence())
+              : std::nullopt};
+}
+
+Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto) {
+  Landmarks landmarks;
+  landmarks.landmarks.reserve(proto.landmark_size());
+  for (const auto& landmark : proto.landmark()) {
+    landmarks.landmarks.push_back(ConvertToLandmark(landmark));
+  }
+  return landmarks;
+}
+
+NormalizedLandmarks ConvertToNormalizedLandmarks(
+    const mediapipe::NormalizedLandmarkList& proto) {
+  NormalizedLandmarks landmarks;
+  landmarks.landmarks.reserve(proto.landmark_size());
+  for (const auto& landmark : proto.landmark()) {
+    landmarks.landmarks.push_back(ConvertToNormalizedLandmark(landmark));
+  }
+  return landmarks;
+}
+
+}  // namespace mediapipe::tasks::components::containers
diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h
new file mode 100644
index 000000000..15b730001
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/landmark.h
@@ -0,0 +1,103 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe::tasks::components::containers {
+constexpr float kLandmarkTolerance = 1e-6;
+
+// Landmark represents a point in 3D space with x, y, z coordinates. The
+// landmark coordinates are in meters. z represents the landmark depth, and the
+// smaller the value the closer the world landmark is to the camera.
+struct Landmark {
+  float x;
+  float y;
+  float z;
+  // Landmark visibility. Should stay unset if not supported.
+  // Float score of whether landmark is visible or occluded by other objects.
+  // Landmark considered as invisible also if it is not present on the screen
+  // (out of scene bounds). Depending on the model, visibility value is either a
+  // sigmoid or an argument of sigmoid.
+  std::optional<float> visibility = std::nullopt;
+  // Landmark presence. Should stay unset if not supported.
+  // Float score of whether landmark is present on the scene (located within
+  // scene bounds). Depending on the model, presence value is either a result of
+  // sigmoid or an argument of sigmoid function to get landmark presence
+  // probability.
+  std::optional<float> presence = std::nullopt;
+  // Landmark name. Should stay unset if not supported.
+  std::optional<std::string> name = std::nullopt;
+};
+
+inline bool operator==(const Landmark& lhs, const Landmark& rhs) {
+  return abs(lhs.x - rhs.x) < kLandmarkTolerance &&
+         abs(lhs.y - rhs.y) < kLandmarkTolerance &&
+         abs(lhs.z - rhs.z) < kLandmarkTolerance;
+}
+
+// A normalized version of the above Landmark struct. All coordinates should be
+// within [0, 1].
+struct NormalizedLandmark {
+  float x;
+  float y;
+  float z;
+  std::optional<float> visibility = std::nullopt;
+  std::optional<float> presence = std::nullopt;
+  std::optional<std::string> name = std::nullopt;
+};
+
+inline bool operator==(const NormalizedLandmark& lhs,
+                       const NormalizedLandmark& rhs) {
+  return abs(lhs.x - rhs.x) < kLandmarkTolerance &&
+         abs(lhs.y - rhs.y) < kLandmarkTolerance &&
+         abs(lhs.z - rhs.z) < kLandmarkTolerance;
+}
+
+// A list of Landmarks.
+struct Landmarks {
+  std::vector<Landmark> landmarks;
+};
+
+// A list of NormalizedLandmarks.
+struct NormalizedLandmarks {
+  std::vector<NormalizedLandmark> landmarks;
+};
+
+// Utility function to convert from Landmark proto to Landmark struct.
+Landmark ConvertToLandmark(const mediapipe::Landmark& proto);
+
+// Utility function to convert from NormalizedLandmark proto to
+// NormalizedLandmark struct.
+NormalizedLandmark ConvertToNormalizedLandmark(
+    const mediapipe::NormalizedLandmark& proto);
+
+// Utility function to convert from LandmarkList proto to Landmarks struct.
+Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto);
+
+// Utility function to convert from NormalizedLandmarkList proto to
+// NormalizedLandmarks struct.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD
index 03ec45f7d..2552e7a10 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD
@@ -154,11 +154,14 @@ cc_library(
 
 cc_library(
     name = "hand_landmarker_result",
+    srcs = ["hand_landmarker_result.cc"],
     hdrs = ["hand_landmarker_result.h"],
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/framework/formats:classification_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/tasks/cc/components/containers:classification_result",
+        "//mediapipe/tasks/cc/components/containers:landmark",
     ],
 )
 
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
index 3bb1ee8d8..ab66fe136 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
@@ -155,9 +155,13 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
         Packet hand_world_landmarks_packet =
             status_or_packets.value()[kHandWorldLandmarksStreamName];
         result_callback(
-            {{handedness_packet.Get<std::vector<mediapipe::ClassificationList>>(),
-              hand_landmarks_packet
-                  .Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
-              hand_world_landmarks_packet
-                  .Get<std::vector<mediapipe::LandmarkList>>()}},
+            ConvertToHandLandmarkerResult(
+                /* handedness= */ handedness_packet
+                    .Get<std::vector<mediapipe::ClassificationList>>(),
+                /* hand_landmarks= */
+                hand_landmarks_packet
+                    .Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
+                /* hand_world_landmarks= */
+                hand_world_landmarks_packet
+                    .Get<std::vector<mediapipe::LandmarkList>>()),
             image_packet.Get<Image>(),
             hand_landmarks_packet.Timestamp().Value() /
                 kMicroSecondsPerMilliSecond);
@@ -193,15 +197,21 @@ absl::StatusOr<HandLandmarkerResult> HandLandmarker::Detect(
   if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
     return {HandLandmarkerResult()};
   }
-  return {{/* handedness= */
-           {output_packets[kHandednessStreamName]
-                .Get<std::vector<mediapipe::ClassificationList>>()},
-           /* hand_landmarks= */
-           {output_packets[kHandLandmarksStreamName]
-                .Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
-           /* hand_world_landmarks */
-           {output_packets[kHandWorldLandmarksStreamName]
-                .Get<std::vector<mediapipe::LandmarkList>>()}}};
+  return ConvertToHandLandmarkerResult(
+      /* handedness= */
+      output_packets[kHandednessStreamName]
+          .Get<std::vector<mediapipe::ClassificationList>>(),
+      /* hand_landmarks= */
+      output_packets[kHandLandmarksStreamName]
+          .Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
+      /* hand_world_landmarks */
+      output_packets[kHandWorldLandmarksStreamName]
+          .Get<std::vector<mediapipe::LandmarkList>>());
 }
 
 absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
@@ -228,17 +238,21 @@ absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
   if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
     return {HandLandmarkerResult()};
   }
-  return {
-      {/* handedness= */
-       {output_packets[kHandednessStreamName]
-            .Get<std::vector<mediapipe::ClassificationList>>()},
-       /* hand_landmarks= */
-       {output_packets[kHandLandmarksStreamName]
-            .Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
-       /* hand_world_landmarks */
-       {output_packets[kHandWorldLandmarksStreamName]
-            .Get<std::vector<mediapipe::LandmarkList>>()}},
-  };
+  return ConvertToHandLandmarkerResult(
+      /* handedness= */
+      output_packets[kHandednessStreamName]
+          .Get<std::vector<mediapipe::ClassificationList>>(),
+      /* hand_landmarks= */
+      output_packets[kHandLandmarksStreamName]
+          .Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
+      /* hand_world_landmarks */
+      output_packets[kHandWorldLandmarksStreamName]
+          .Get<std::vector<mediapipe::LandmarkList>>());
 }
 
 absl::Status HandLandmarker::DetectAsync(
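The live-stream path above now hands the converted struct to the user callback. A hedged sketch of registering such a callback follows; the options type, field names, and callback signature are assumptions based on the usual C++ tasks options pattern, not text from this patch.

#include <memory>

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"  // assumed header

using ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerOptions;
using ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerResult;

std::unique_ptr<HandLandmarkerOptions> MakeLiveStreamOptions() {
  auto options = std::make_unique<HandLandmarkerOptions>();
  options->base_options.model_asset_path = "hand_landmarker.task";  // placeholder
  options->running_mode = mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  options->result_callback = [](absl::StatusOr<HandLandmarkerResult> result,
                                const mediapipe::Image& image,
                                int64_t timestamp_ms) {
    if (!result.ok()) return;
    // The callback now receives struct containers rather than raw protos.
    for (const auto& hand : result->handedness) {
      // e.g. inspect hand.categories[0].category_name and .score here.
    }
  };
  return options;
}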
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc
new file mode 100644
index 000000000..9d2ae2be8
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc
@@ -0,0 +1,56 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
+
+#include <algorithm>
+
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/components/containers/landmark.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace hand_landmarker {
+
+HandLandmarkerResult ConvertToHandLandmarkerResult(
+    const std::vector<mediapipe::ClassificationList>& handedness_proto,
+    const std::vector<mediapipe::NormalizedLandmarkList>& hand_landmarks_proto,
+    const std::vector<mediapipe::LandmarkList>& hand_world_landmarks_proto) {
+  HandLandmarkerResult result;
+  result.handedness.resize(handedness_proto.size());
+  result.hand_landmarks.resize(hand_landmarks_proto.size());
+  result.hand_world_landmarks.resize(hand_world_landmarks_proto.size());
+  std::transform(handedness_proto.begin(), handedness_proto.end(),
+                 result.handedness.begin(),
+                 [](const mediapipe::ClassificationList& classification_list) {
+                   return components::containers::ConvertToClassifications(
+                       classification_list);
+                 });
+  std::transform(hand_landmarks_proto.begin(), hand_landmarks_proto.end(),
+                 result.hand_landmarks.begin(),
+                 components::containers::ConvertToNormalizedLandmarks);
+  std::transform(hand_world_landmarks_proto.begin(),
+                 hand_world_landmarks_proto.end(),
+                 result.hand_world_landmarks.begin(),
+                 components::containers::ConvertToLandmarks);
+  return result;
+}
+
+}  // namespace hand_landmarker
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h
index 5e51c244e..1bca8e66a 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ limitations under the License.
 
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/components/containers/landmark.h"
 
 namespace mediapipe {
 namespace tasks {
@@ -28,13 +30,18 @@ namespace hand_landmarker {
 // element represents a single hand detected in the image.
 struct HandLandmarkerResult {
   // Classification of handedness.
-  std::vector<mediapipe::ClassificationList> handedness;
+  std::vector<components::containers::Classifications> handedness;
   // Detected hand landmarks in normalized image coordinates.
-  std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
+  std::vector<components::containers::NormalizedLandmarks> hand_landmarks;
   // Detected hand landmarks in world coordinates.
-  std::vector<mediapipe::LandmarkList> hand_world_landmarks;
+  std::vector<components::containers::Landmarks> hand_world_landmarks;
 };
 
+HandLandmarkerResult ConvertToHandLandmarkerResult(
+    const std::vector<mediapipe::ClassificationList>& handedness_proto,
+    const std::vector<mediapipe::NormalizedLandmarkList>& hand_landmarks_proto,
+    const std::vector<mediapipe::LandmarkList>& hand_world_landmarks_proto);
+
 }  // namespace hand_landmarker
 }  // namespace vision
 }  // namespace tasks
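Since HandLandmarkerResult now exposes plain structs, results can be read without touching proto accessors. A small sketch using only the fields defined above; the printing logic is illustrative:

#include <cstddef>
#include <iostream>

#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"

namespace hand_landmarker = mediapipe::tasks::vision::hand_landmarker;

void PrintResult(const hand_landmarker::HandLandmarkerResult& result) {
  for (std::size_t i = 0; i < result.handedness.size(); ++i) {
    // Top handedness category for hand i, e.g. "Left" with its score.
    const auto& top = result.handedness[i].categories[0];
    bool is_left = (top.category_name == "Left");
    std::cout << "hand " << i << (is_left ? " left " : " other ")
              << top.score << "\n";
    // Normalized image-space landmarks for the same hand.
    for (const auto& lm : result.hand_landmarks[i].landmarks) {
      std::cout << lm.x << ", " << lm.y << ", " << lm.z << "\n";
    }
  }
}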
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc
new file mode 100644
index 000000000..109749b01
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc
@@ -0,0 +1,88 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
+
+#include <optional>
+#include <vector>
+
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/components/containers/landmark.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace hand_landmarker {
+
+TEST(ConvertFromProto, Succeeds) {
+  mediapipe::ClassificationList classification_list_proto;
+  mediapipe::Classification& classification_proto =
+      *classification_list_proto.add_classification();
+  classification_proto.set_index(1);
+  classification_proto.set_score(0.5);
+  classification_proto.set_label("Left");
+  classification_proto.set_display_name("Left_Hand");
+
+  mediapipe::NormalizedLandmarkList normalized_landmark_list_proto;
+  mediapipe::NormalizedLandmark& normalized_landmark_proto =
+      *normalized_landmark_list_proto.add_landmark();
+  normalized_landmark_proto.set_x(0.1);
+  normalized_landmark_proto.set_y(0.2);
+  normalized_landmark_proto.set_z(0.3);
+
+  mediapipe::LandmarkList landmark_list_proto;
+  mediapipe::Landmark& landmark_proto = *landmark_list_proto.add_landmark();
+  landmark_proto.set_x(3.1);
+  landmark_proto.set_y(5.2);
+  landmark_proto.set_z(4.3);
+
+  std::vector<mediapipe::ClassificationList> classification_lists = {
+      classification_list_proto};
+  std::vector<mediapipe::NormalizedLandmarkList> normalized_landmarks_lists = {
+      normalized_landmark_list_proto};
+  std::vector<mediapipe::LandmarkList> landmarks_lists = {landmark_list_proto};
+
+  HandLandmarkerResult hand_landmarker_result = ConvertToHandLandmarkerResult(
+      classification_lists, normalized_landmarks_lists, landmarks_lists);
+
+  EXPECT_EQ(hand_landmarker_result.handedness.size(), 1);
+  EXPECT_EQ(hand_landmarker_result.handedness[0].categories.size(), 1);
+  EXPECT_THAT(
+      hand_landmarker_result.handedness[0].categories[0],
+      testing::FieldsAre(1, testing::FloatEq(0.5), "Left", "Left_Hand"));
+
+  EXPECT_EQ(hand_landmarker_result.hand_landmarks.size(), 1);
+  EXPECT_EQ(hand_landmarker_result.hand_landmarks[0].landmarks.size(), 1);
+  EXPECT_THAT(hand_landmarker_result.hand_landmarks[0].landmarks[0],
+              testing::FieldsAre(testing::FloatEq(0.1), testing::FloatEq(0.2),
+                                 testing::FloatEq(0.3), std::nullopt,
+                                 std::nullopt, std::nullopt));
+
+  EXPECT_EQ(hand_landmarker_result.hand_world_landmarks.size(), 1);
+  EXPECT_EQ(hand_landmarker_result.hand_world_landmarks[0].landmarks.size(), 1);
+  EXPECT_THAT(hand_landmarker_result.hand_world_landmarks[0].landmarks[0],
+              testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2),
+                                 testing::FloatEq(4.3), std::nullopt,
+                                 std::nullopt, std::nullopt));
+}
+
+}  // namespace hand_landmarker
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
index 94d1b1c12..b21f1bee9 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc
@@ -32,6 +32,8 @@ limitations under the License.
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/components/containers/landmark.h"
 #include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@@ -50,18 +52,16 @@ namespace {
 
 using ::file::Defaults;
 using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::containers::ConvertToClassifications;
+using ::mediapipe::tasks::components::containers::ConvertToNormalizedLandmarks;
 using ::mediapipe::tasks::components::containers::RectF;
 using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
-using ::testing::EqualsProto;
 using ::testing::HasSubstr;
 using ::testing::Optional;
-using ::testing::Pointwise;
 using ::testing::TestParamInfo;
 using ::testing::TestWithParam;
 using ::testing::Values;
-using ::testing::proto::Approximately;
-using ::testing::proto::Partially;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task";
@@ -74,7 +74,6 @@ constexpr char kPointingUpImage[] = "pointing_up.jpg";
 constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg";
 constexpr char kNoHandsImage[] = "cats_and_dogs.jpg";
 
-constexpr float kLandmarksFractionDiff = 0.03;  // percentage
 constexpr float kLandmarksAbsMargin = 0.03;
 constexpr float kHandednessMargin = 0.05;
 
@@ -101,13 +100,47 @@ HandLandmarkerResult GetExpectedHandLandmarkerResult(
     const auto landmarks_detection_result =
         GetLandmarksDetectionResult(file_name);
     expected_results.hand_landmarks.push_back(
-        landmarks_detection_result.landmarks());
+        ConvertToNormalizedLandmarks(landmarks_detection_result.landmarks()));
     expected_results.handedness.push_back(
-        landmarks_detection_result.classifications());
+        ConvertToClassifications(landmarks_detection_result.classifications()));
   }
   return expected_results;
 }
 
+MATCHER_P2(HandednessMatches, expected_handedness, tolerance, "") {
+  for (int i = 0; i < arg.size(); i++) {
+    for (int j = 0; j < arg[i].categories.size(); j++) {
+      if (arg[i].categories[j].index !=
+          expected_handedness[i].categories[j].index) {
+        return false;
+      }
+      if (std::abs(arg[i].categories[j].score -
+                   expected_handedness[i].categories[j].score) > tolerance) {
+        return false;
+      }
+      if (arg[i].categories[j].category_name !=
+          expected_handedness[i].categories[j].category_name) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") {
+  for (int i = 0; i < arg.size(); i++) {
+    for (int j = 0; j < arg[i].landmarks.size(); j++) {
+      if (std::abs(arg[i].landmarks[j].x -
+                   expected_landmarks[i].landmarks[j].x) > toleration ||
+          std::abs(arg[i].landmarks[j].y -
+                   expected_landmarks[i].landmarks[j].y) > toleration) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 void ExpectHandLandmarkerResultsCorrect(
     const HandLandmarkerResult& actual_results,
     const HandLandmarkerResult& expected_results) {
@@ -119,16 +152,15 @@ void ExpectHandLandmarkerResultsCorrect(
   ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size());
   ASSERT_EQ(actual_handedness.size(), expected_handedness.size());
 
+  if (actual_landmarks.empty()) {
+    return;
+  }
+
   ASSERT_GE(actual_landmarks.size(), 1);
-  EXPECT_THAT(
-      actual_handedness,
-      Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin),
-                expected_handedness));
+  EXPECT_THAT(actual_handedness,
+              HandednessMatches(expected_handedness, kHandednessMargin));
   EXPECT_THAT(actual_landmarks,
-              Pointwise(Approximately(Partially(EqualsProto()),
-                                      /*margin=*/kLandmarksAbsMargin,
-                                      /*fraction=*/kLandmarksFractionDiff),
-                        expected_landmarks));
+              LandmarksMatches(expected_landmarks, kLandmarksAbsMargin));
 }
 
 }  // namespace

From ccd1461add4b6ecc974a46df597bcac8c154bbc9 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Mon, 23 Jan 2023 13:36:32 -0800
Subject: [PATCH 343/346] Don't error in ExternalFile handler on Windows if
 FileContent is provided

PiperOrigin-RevId: 504069137
---
 mediapipe/tasks/cc/core/external_file_handler.cc | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc
index 33dfeca0b..ff30bea72 100644
--- a/mediapipe/tasks/cc/core/external_file_handler.cc
+++ b/mediapipe/tasks/cc/core/external_file_handler.cc
@@ -84,12 +84,6 @@ ExternalFileHandler::CreateFromExternalFile(
 }
 
 absl::Status ExternalFileHandler::MapExternalFile() {
-// TODO: Add Windows support
-#ifdef _WIN32
-  return CreateStatusWithPayload(StatusCode::kFailedPrecondition,
-                                 "File loading is not yet supported on Windows",
-                                 MediaPipeTasksStatus::kFileReadError);
-#else
   if (!external_file_.file_content().empty()) {
     return absl::OkStatus();
   } else if (external_file_.has_file_pointer_meta()) {
@@ -106,6 +100,13 @@ absl::Status ExternalFileHandler::MapExternalFile() {
     }
     return absl::OkStatus();
   }
+
+// TODO: Add Windows support
+#ifdef _WIN32
+  return CreateStatusWithPayload(StatusCode::kFailedPrecondition,
+                                 "File loading is not yet supported on Windows",
+                                 MediaPipeTasksStatus::kFileReadError);
+#else
   if (external_file_.file_name().empty() &&
       !external_file_.has_file_descriptor_meta()) {
     return CreateStatusWithPayload(
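With this change, a model supplied as in-memory bytes no longer hits the Windows error path, because the file_content branch is checked before the _WIN32 guard. A hedged sketch of feeding bytes through that branch; reading from a file here is just one possible byte source, and the path is a placeholder:

#include <fstream>
#include <sstream>
#include <string>

#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"

// Read a model file into memory; any byte source would do.
std::string ReadFileBytes(const std::string& path) {
  std::ifstream stream(path, std::ios::binary);
  std::ostringstream bytes;
  bytes << stream.rdbuf();
  return bytes.str();
}

// Hand the bytes to the task via file_content; MapExternalFile() now
// returns OkStatus for this case on all platforms, including Windows.
mediapipe::tasks::core::proto::ExternalFile MakeInMemoryFile() {
  mediapipe::tasks::core::proto::ExternalFile external_file;
  external_file.set_file_content(ReadFileBytes("model.tflite"));  // placeholder
  return external_file;
}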
From 873d7181bf60fb29a5a441c8207a219029dafb98 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 23 Jan 2023 14:13:38 -0800
Subject: [PATCH 344/346] Add mediapipe tasks face detector graph

PiperOrigin-RevId: 504078951
---
 mediapipe/tasks/cc/vision/face_detector/BUILD |  61 +++++
 .../face_detector/face_detector_graph.cc      | 208 ++++++++++++++++++
 .../face_detector/face_detector_graph_test.cc | 183 +++++++++++++++
 .../tasks/cc/vision/face_detector/proto/BUILD |  31 +++
 .../proto/face_detector_graph_options.proto   |  42 ++++
 mediapipe/tasks/testdata/vision/BUILD         |   8 +
 .../vision/portrait_expected_detection.pbtxt  |  35 +++
 third_party/external_files.bzl                |  20 +-
 8 files changed, 584 insertions(+), 4 deletions(-)
 create mode 100644 mediapipe/tasks/cc/vision/face_detector/BUILD
 create mode 100644 mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
 create mode 100644 mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc
 create mode 100644 mediapipe/tasks/cc/vision/face_detector/proto/BUILD
 create mode 100644 mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto
 create mode 100644 mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt

diff --git a/mediapipe/tasks/cc/vision/face_detector/BUILD b/mediapipe/tasks/cc/vision/face_detector/BUILD
new file mode 100644
index 000000000..09af34aa0
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/BUILD
@@ -0,0 +1,61 @@
+# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = [
+    # "//mediapipe/tasks:internal",
+    "//visibility:public",
+])
+
+licenses(["notice"])
+
+cc_library(
+    name = "face_detector_graph",
+    srcs = ["face_detector_graph.cc"],
+    deps = [
+        "//mediapipe/calculators/core:clip_vector_size_calculator",
+        "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:inference_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
+        "//mediapipe/calculators/util:detection_projection_calculator",
+        "//mediapipe/calculators/util:detections_to_rects_calculator",
+        "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
+        "//mediapipe/calculators/util:non_max_suppression_calculator",
+        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
+        "//mediapipe/calculators/util:rect_transformation_calculator",
+        "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/core:model_task_graph",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
+        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+    ],
+    alwayslink = 1,
+)

diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
new file mode 100644
index 000000000..6b60621a6
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
@@ -0,0 +1,208 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
+#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
+#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
+#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
+#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
+#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
+#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
+#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/model_task_graph.h"
+#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace face_detector {
+
+using ::mediapipe::NormalizedRect;
+using ::mediapipe::Tensor;
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+using ::mediapipe::tasks::vision::face_detector::proto::
+    FaceDetectorGraphOptions;
+
+namespace {
+constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kDetectionsTag[] = "DETECTIONS";
+
+void ConfigureSsdAnchorsCalculator(
+    mediapipe::SsdAnchorsCalculatorOptions* options) {
+  // TODO config SSD anchors parameters from metadata.
+  options->set_num_layers(1);
+  options->set_min_scale(0.1484375);
+  options->set_max_scale(0.75);
+  options->set_input_size_height(192);
+  options->set_input_size_width(192);
+  options->set_anchor_offset_x(0.5);
+  options->set_anchor_offset_y(0.5);
+  options->add_strides(4);
+  options->add_aspect_ratios(1.0);
+  options->set_fixed_anchor_size(true);
+  options->set_interpolated_scale_aspect_ratio(0.0);
+}
+
+void ConfigureTensorsToDetectionsCalculator(
+    const FaceDetectorGraphOptions& tasks_options,
+    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
+  // TODO use metadata to configure these fields.
+  options->set_num_classes(1);
+  options->set_num_boxes(2304);
+  options->set_num_coords(16);
+  options->set_box_coord_offset(0);
+  options->set_keypoint_coord_offset(4);
+  options->set_num_keypoints(6);
+  options->set_num_values_per_keypoint(2);
+  options->set_sigmoid_score(true);
+  options->set_score_clipping_thresh(100.0);
+  options->set_reverse_output_order(true);
+  options->set_min_score_thresh(tasks_options.min_detection_confidence());
+  options->set_x_scale(192.0);
+  options->set_y_scale(192.0);
+  options->set_w_scale(192.0);
+  options->set_h_scale(192.0);
+}
+
+void ConfigureNonMaxSuppressionCalculator(
+    const FaceDetectorGraphOptions& tasks_options,
+    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
+  options->set_min_suppression_threshold(
+      tasks_options.min_suppression_threshold());
+  options->set_overlap_type(
+      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
+  options->set_algorithm(
+      mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
+}
+
+}  // namespace
+
+class FaceDetectorGraph : public core::ModelTaskGraph {
+ public:
+  absl::StatusOr<CalculatorGraphConfig> GetConfig(
+      SubgraphContext* sc) override {
+    ASSIGN_OR_RETURN(const auto* model_resources,
+                     CreateModelResources<FaceDetectorGraphOptions>(sc));
+    Graph graph;
+    ASSIGN_OR_RETURN(auto face_detections,
+                     BuildFaceDetectionSubgraph(
+                         sc->Options<FaceDetectorGraphOptions>(),
+                         *model_resources, graph[Input<Image>(kImageTag)],
+                         graph[Input<NormalizedRect>(kNormRectTag)], graph));
+    face_detections >>
+        graph[Output<std::vector<Detection>>(kDetectionsTag)];
+    return graph.GetConfig();
+  }
+
+ private:
+  absl::StatusOr<Source<std::vector<Detection>>> BuildFaceDetectionSubgraph(
+      const FaceDetectorGraphOptions& subgraph_options,
+      const core::ModelResources& model_resources, Source<Image> image_in,
+      Source<NormalizedRect> norm_rect_in, Graph& graph) {
+    // Image preprocessing subgraph to convert image to tensor for the tflite
+    // model.
+    auto& preprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
+    bool use_gpu =
+        components::processors::DetermineImagePreprocessingGpuBackend(
+            subgraph_options.base_options().acceleration());
+    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
+        model_resources, use_gpu,
+        &preprocessing.GetOptions<
+            components::processors::proto::ImagePreprocessingGraphOptions>()));
+    auto& image_to_tensor_options =
+        *preprocessing
+             .GetOptions<components::processors::proto::
+                             ImagePreprocessingGraphOptions>()
+             .mutable_image_to_tensor_options();
+    image_to_tensor_options.set_keep_aspect_ratio(true);
+    image_to_tensor_options.set_border_mode(
+        mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
+    image_in >> preprocessing.In("IMAGE");
+    norm_rect_in >> preprocessing.In("NORM_RECT");
+    auto preprocessed_tensors = preprocessing.Out("TENSORS");
+    auto matrix = preprocessing.Out("MATRIX");
+
+    // Face detection model inference.
+    auto& inference = AddInference(
+        model_resources, subgraph_options.base_options().acceleration(), graph);
+    preprocessed_tensors >> inference.In("TENSORS");
+    auto model_output_tensors =
+        inference.Out("TENSORS").Cast<std::vector<Tensor>>();
+
+    // Generates a single side packet containing a vector of SSD anchors.
+    auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
+    ConfigureSsdAnchorsCalculator(
+        &ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>());
+    auto anchors = ssd_anchor.SideOut("");
+
+    // Converts output tensors to Detections.
+    auto& tensors_to_detections =
+        graph.AddNode("TensorsToDetectionsCalculator");
+    ConfigureTensorsToDetectionsCalculator(
+        subgraph_options,
+        &tensors_to_detections
+             .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
+    model_output_tensors >> tensors_to_detections.In("TENSORS");
+    anchors >> tensors_to_detections.SideIn("ANCHORS");
+    auto detections = tensors_to_detections.Out("DETECTIONS");
+
+    // Non maximum suppression removes redundant face detections.
+    auto& non_maximum_suppression =
+        graph.AddNode("NonMaxSuppressionCalculator");
+    ConfigureNonMaxSuppressionCalculator(
+        subgraph_options,
+        &non_maximum_suppression
+             .GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>());
+    detections >> non_maximum_suppression.In("");
+    auto nms_detections = non_maximum_suppression.Out("");
+
+    // Projects detections back into the input image coordinates system.
+    auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
+    nms_detections >> detection_projection.In("DETECTIONS");
+    matrix >> detection_projection.In("PROJECTION_MATRIX");
+    auto face_detections =
+        detection_projection[Output<std::vector<Detection>>("DETECTIONS")];
+
+    return {face_detections};
+  }
+};
+
+REGISTER_MEDIAPIPE_GRAPH(
+    ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph);
+
+}  // namespace face_detector
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc
new file mode 100644
index 000000000..fc1f49f13
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/file_helpers.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace face_detector {
+namespace {
+
+using ::file::Defaults;
+using ::file::GetTextProto;
+using ::mediapipe::NormalizedRect;
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::core::TaskRunner;
+using ::mediapipe::tasks::vision::DecodeImageFromFile;
+using ::mediapipe::tasks::vision::face_detector::proto::
+    FaceDetectorGraphOptions;
+using ::testing::EqualsProto;
+using ::testing::Pointwise;
+using ::testing::TestParamInfo;
+using ::testing::TestWithParam;
+using ::testing::Values;
+using ::testing::proto::Approximately;
+using ::testing::proto::Partially;
+
+constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
+constexpr char kFullRangeBlazeFaceModel[] = "face_detection_full_range.tflite";
+constexpr char kFullRangeSparseBlazeFaceModel[] =
+    "face_detection_full_range_sparse.tflite";
+constexpr char kPortraitImage[] = "portrait.jpg";
+constexpr char kPortraitExpectedDetection[] =
+    "portrait_expected_detection.pbtxt";
+
+constexpr char kImageTag[] = "IMAGE";
+constexpr char kImageName[] = "image";
+constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kNormRectName[] = "norm_rect";
+constexpr char kDetectionsTag[] = "DETECTIONS";
+constexpr char kDetectionsName[] = "detections";
+
+constexpr float kFaceDetectionMaxDiff = 0.01;
+
+// Helper function to create a TaskRunner.
+absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
+    absl::string_view model_name) {
+  Graph graph;
+
+  auto& face_detector_graph =
+      graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
+
+  auto options = std::make_unique<FaceDetectorGraphOptions>();
+  options->mutable_base_options()->mutable_model_asset()->set_file_name(
+      JoinPath("./", kTestDataDirectory, model_name));
+  options->set_min_detection_confidence(0.6);
+  options->set_min_suppression_threshold(0.3);
+  face_detector_graph.GetOptions<FaceDetectorGraphOptions>().Swap(
+      options.get());
+
+  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
+      face_detector_graph.In(kImageTag);
+  graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
+      face_detector_graph.In(kNormRectTag);
+
+  face_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >>
+      graph[Output<std::vector<Detection>>(kDetectionsTag)];
+
+  return TaskRunner::Create(
+      graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
+}
+
+Detection GetExpectedFaceDetectionResult(absl::string_view file_name) {
+  Detection detection;
+  CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
+                        &detection, Defaults()))
+      << "Expected face detection result does not exist.";
+  return detection;
+}
+
+struct TestParams {
+  // The name of this test, for convenience when displaying test results.
+  std::string test_name;
+  // The filename of face landmark detection model.
+  std::string face_detection_model_name;
+  // The filename of test image.
+  std::string test_image_name;
+  // Expected face detection results.
+  std::vector<Detection> expected_result;
+};
+
+class FaceDetectorGraphTest : public testing::TestWithParam<TestParams> {};
+
+TEST_P(FaceDetectorGraphTest, Succeed) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                GetParam().test_image_name)));
+  NormalizedRect input_norm_rect;
+  input_norm_rect.set_x_center(0.5);
+  input_norm_rect.set_y_center(0.5);
+  input_norm_rect.set_width(1.0);
+  input_norm_rect.set_height(1.0);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name));
+  auto output_packets = task_runner->Process(
+      {{kImageName, MakePacket<Image>(std::move(image))},
+       {kNormRectName,
+        MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
+  MP_ASSERT_OK(output_packets);
+  const std::vector<Detection>& face_detections =
+      (*output_packets)[kDetectionsName].Get<std::vector<Detection>>();
+  EXPECT_THAT(face_detections, Pointwise(Approximately(Partially(EqualsProto()),
+                                                       kFaceDetectionMaxDiff),
+                                         GetParam().expected_result));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    FaceDetectorGraphTest, FaceDetectorGraphTest,
+    Values(TestParams{.test_name = "FullRange",
+                      .face_detection_model_name = kFullRangeBlazeFaceModel,
+                      .test_image_name = kPortraitImage,
+                      .expected_result = {GetExpectedFaceDetectionResult(
+                          kPortraitExpectedDetection)}},
+           TestParams{
+               .test_name = "FullRangeSparse",
+               .face_detection_model_name = kFullRangeSparseBlazeFaceModel,
+               .test_image_name = kPortraitImage,
+               .expected_result = {GetExpectedFaceDetectionResult(
+                   kPortraitExpectedDetection)}}),
+    [](const TestParamInfo<FaceDetectorGraphTest::ParamType>& info) {
+      return info.param.test_name;
+    });
+
+}  // namespace
+}  // namespace face_detector
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/BUILD b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD
new file mode 100644
index 000000000..ca9a6f8c4
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
+
+package(default_visibility = [
+    "//mediapipe/tasks:internal",
+])
+
+licenses(["notice"])
+
+mediapipe_proto_library(
+    name = "face_detector_graph_options_proto",
+    srcs = ["face_detector_graph_options.proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/tasks/cc/core/proto:base_options_proto",
+    ],
+)

diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto
new file mode 100644
index 000000000..a58338288
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto
@@ -0,0 +1,42 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe.tasks.vision.face_detector.proto;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
+import "mediapipe/tasks/cc/core/proto/base_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.facedetector.proto";
+option java_outer_classname = "FaceDetectorGraphOptionsProto";
+
+message FaceDetectorGraphOptions {
+  extend mediapipe.CalculatorOptions {
+    optional FaceDetectorGraphOptions ext = 502141897;
+  }
+  // Base options for configuring Task library, such as specifying the TfLite
+  // model file with metadata, accelerator options, etc.
+  optional core.proto.BaseOptions base_options = 1;
+
+  // Minimum confidence value ([0.0, 1.0]) for a face detection to be
+  // considered successful.
+  optional float min_detection_confidence = 2 [default = 0.5];
+
+  // IoU threshold ([0.0, 1.0]) above which detections are considered
+  // duplicates during non-maximum suppression.
+  optional float min_suppression_threshold = 3 [default = 0.5];
+}
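For orientation, this is how the new options message is populated from C++, as exercised in the test earlier in this commit; the model filename below is a placeholder, and the thresholds are the values the test uses:

#include <memory>

#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"

using ::mediapipe::tasks::vision::face_detector::proto::FaceDetectorGraphOptions;

std::unique_ptr<FaceDetectorGraphOptions> MakeOptions() {
  auto options = std::make_unique<FaceDetectorGraphOptions>();
  // Placeholder path; any face detection model asset with compatible
  // metadata works here.
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      "face_detection_full_range.tflite");
  options->set_min_detection_confidence(0.6);
  options->set_min_suppression_threshold(0.3);
  return options;
}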
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 607245700..09f830aba 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -37,6 +37,8 @@ mediapipe_files(srcs = [
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
     "deeplabv3.tflite",
+    "face_detection_full_range.tflite",
+    "face_detection_full_range_sparse.tflite",
    "fist.jpg",
     "fist.png",
     "hand_landmark_full.tflite",
@@ -58,6 +60,7 @@ mediapipe_files(srcs = [
     "palm_detection_full.tflite",
     "pointing_up.jpg",
     "pointing_up_rotated.jpg",
+    "portrait.jpg",
     "right_hands.jpg",
     "right_hands_rotated.jpg",
     "segmentation_golden_rotation0.png",
@@ -79,6 +82,7 @@ exports_files(
         "expected_right_down_hand_landmarks.prototxt",
         "expected_right_up_hand_landmarks.prototxt",
         "gesture_recognizer.task",
+        "portrait_expected_detection.pbtxt",
     ],
 )
 
@@ -106,6 +110,7 @@ filegroup(
         "multi_objects_rotated.jpg",
         "pointing_up.jpg",
         "pointing_up_rotated.jpg",
+        "portrait.jpg",
         "right_hands.jpg",
         "right_hands_rotated.jpg",
         "segmentation_golden_rotation0.png",
@@ -129,6 +134,8 @@ filegroup(
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
         "deeplabv3.tflite",
+        "face_detection_full_range.tflite",
+        "face_detection_full_range_sparse.tflite",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "hand_landmarker.task",
@@ -161,6 +168,7 @@ filegroup(
         "hand_detector_result_two_hands.pbtxt",
         "pointing_up_landmarks.pbtxt",
         "pointing_up_rotated_landmarks.pbtxt",
+        "portrait_expected_detection.pbtxt",
         "thumb_up_landmarks.pbtxt",
         "thumb_up_rotated_landmarks.pbtxt",
         "victory_landmarks.pbtxt",

diff --git a/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt b/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt
new file mode 100644
index 000000000..775f4479b
--- /dev/null
+++ b/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt
@@ -0,0 +1,35 @@
+# proto-file: mediapipe/framework/formats/detection.proto
+# proto-message: Detection
+location_data {
+  format: RELATIVE_BOUNDING_BOX
+  relative_bounding_box {
+    xmin: 0.35494408
+    ymin: 0.1059662
+    width: 0.28768203
+    height: 0.23037356
+  }
+  relative_keypoints {
+    x: 0.44416338
+    y: 0.17643969
+  }
+  relative_keypoints {
+    x: 0.55514044
+    y: 0.17731678
+  }
+  relative_keypoints {
+    x: 0.5046702
+    y: 0.2265771
+  }
+  relative_keypoints {
+    x: 0.50227845
+    y: 0.2719954
+  }
+  relative_keypoints {
+    x: 0.37245658
+    y: 0.20143759
+  }
+  relative_keypoints {
+    x: 0.6084143
+    y: 0.20409837
+  }
+}

diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 5adfbdfc6..1d9239c83 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -240,14 +240,14 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_face_detection_full_range_sparse_tflite",
-        sha256 = "671dd2f9ed11a78436fc21cc42357a803dfc6f73e9fb86541be942d5716c2dce",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1661875739104017"],
+        sha256 = "2c3728e6da56f21e21a320433396fb06d40d9088f2247c05e5635a688d45dfe1",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1674261618323821"],
     )
 
     http_file(
         name = "com_google_mediapipe_face_detection_full_range_tflite",
"com_google_mediapipe_face_detection_full_range_tflite", - sha256 = "99bf9494d84f50acc6617d89873f71bf6635a841ea699c17cb3377f9507cfec3", - urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1661875742733283"], + sha256 = "3698b18f063835bc609069ef052228fbe86d9c9a6dc8dcb7c7c2d69aed2b181b", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1674261620964007"], ) http_file( @@ -712,6 +712,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], ) + http_file( + name = "com_google_mediapipe_portrait_expected_detection_pbtxt", + sha256 = "bb54e08e87844ef14bb185d5cb808908eb6011bfa6db48bd22d9650f6fda338b", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1674261627835475"], + ) + + http_file( + name = "com_google_mediapipe_portrait_jpg", + sha256 = "a6f11efaa834706db23f275b6115058fa87fc7f14362681e6abe14e82749de3e", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait.jpg?generation=1674261630039907"], + ) + http_file( name = "com_google_mediapipe_pose_detection_tflite", sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85", From 2465e47b01e6883c22c7a7047e9dd087e93e7615 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 23 Jan 2023 16:41:32 -0800 Subject: [PATCH 345/346] Stream/SidePacket == and != operators PiperOrigin-RevId: 504114182 --- mediapipe/framework/api2/builder.h | 13 +++++++ mediapipe/framework/api2/builder_test.cc | 46 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 2a98c4166..da09acc83 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -206,6 +206,16 @@ class SourceImpl { return ConnectTo(dest); } + template + bool operator==(const SourceImpl& other) { + return base_ == other.base_; + } + + template + bool operator!=(const SourceImpl& other) { + return !(*this == other); + } + Src& SetName(std::string name) { base_->name_ = std::move(name); return *this; @@ -218,6 +228,9 @@ class SourceImpl { } private: + template + friend class SourceImpl; + // Never null. 
diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc
index 08f4f0ca1..194f1b8ff 100644
--- a/mediapipe/framework/api2/builder_test.cc
+++ b/mediapipe/framework/api2/builder_test.cc
@@ -494,5 +494,51 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
 }
 
+TEST(BuilderTest, TestStreamEqualsNotEqualsOperators) {
+  Graph graph;
+  Stream<AnyType> input0 = graph.In(0);
+  EXPECT_TRUE(input0 == input0);
+  EXPECT_FALSE(input0 != input0);
+
+  EXPECT_TRUE(input0 == input0.Cast<int>());
+  EXPECT_FALSE(input0.Cast<int>() != input0);
+
+  EXPECT_TRUE(input0.Cast<int>() == input0.Cast<int>());
+  EXPECT_FALSE(input0.Cast<int>() != input0.Cast<int>());
+
+  Stream<AnyType> input1 = graph.In(1);
+  EXPECT_FALSE(input0 == input1);
+  EXPECT_TRUE(input0 != input1);
+
+  input1 = input0;
+  EXPECT_TRUE(input0 == input1);
+  EXPECT_FALSE(input0 != input1);
+  EXPECT_TRUE(input0.Cast<int>() == input1.Cast<int>());
+  EXPECT_FALSE(input0.Cast<int>() != input1.Cast<int>());
+}
+
+TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) {
+  Graph graph;
+  SidePacket<AnyType> side_input0 = graph.SideIn(0);
+  EXPECT_TRUE(side_input0 == side_input0);
+  EXPECT_FALSE(side_input0 != side_input0);
+
+  EXPECT_TRUE(side_input0 == side_input0.Cast<int>());
+  EXPECT_FALSE(side_input0.Cast<int>() != side_input0);
+
+  EXPECT_TRUE(side_input0.Cast<int>() == side_input0.Cast<int>());
+  EXPECT_FALSE(side_input0.Cast<int>() != side_input0.Cast<int>());
+
+  SidePacket<AnyType> side_input1 = graph.SideIn(1);
+  EXPECT_FALSE(side_input0 == side_input1);
+  EXPECT_TRUE(side_input0 != side_input1);
+
+  side_input1 = side_input0;
+  EXPECT_TRUE(side_input0 == side_input1);
+  EXPECT_FALSE(side_input0 != side_input1);
+  EXPECT_TRUE(side_input0.Cast<int>() == side_input1.Cast<int>());
+  EXPECT_FALSE(side_input0.Cast<int>() != side_input1.Cast<int>());
+}
+
 }  // namespace
 }  // namespace mediapipe::api2::builder

From 4e135ccdb9273c2e465b701130529bb3d4c77172 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 24 Jan 2023 10:36:36 -0800
Subject: [PATCH 346/346] Internal Model Maker change.

PiperOrigin-RevId: 504315641
---
 mediapipe/model_maker/python/text/text_classifier/BUILD | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD
index 43f2b6c75..ac5b04f20 100644
--- a/mediapipe/model_maker/python/text/text_classifier/BUILD
+++ b/mediapipe/model_maker/python/text/text_classifier/BUILD
@@ -140,7 +140,11 @@ py_test(
         "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
         "//mediapipe/model_maker/python/text/text_classifier/testdata",
     ],
-    tags = ["requires-net:external"],
+    tags = [
+        "notsan",
+        "requires-mem:16g",
+        "requires-net:external",
+    ],
     deps = [
         ":text_classifier_import",
         "//mediapipe/tasks/python/test:test_utils",