Merge branch 'google:master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-05 00:17:26 +05:30 committed by GitHub
commit b1d1ef0124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 2407 additions and 669 deletions

1
.gitignore vendored
View File

@ -2,5 +2,6 @@ bazel-*
mediapipe/MediaPipe.xcodeproj mediapipe/MediaPipe.xcodeproj
mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user
mediapipe/provisioning_profile.mobileprovision mediapipe/provisioning_profile.mobileprovision
node_modules/
.configure.bazelrc .configure.bazelrc
.user.bazelrc .user.bazelrc

View File

@ -1,4 +1,4 @@
# Copyright 2019 The MediaPipe Authors. # Copyright 2022 The MediaPipe Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,4 +14,9 @@
licenses(["notice"]) licenses(["notice"])
exports_files(["LICENSE"]) exports_files([
"LICENSE",
"tsconfig.json",
"package.json",
"yarn.lock",
])

View File

@ -501,5 +501,48 @@ libedgetpu_dependencies()
load("@coral_crosstool//:configure.bzl", "cc_crosstool") load("@coral_crosstool//:configure.bzl", "cc_crosstool")
cc_crosstool(name = "crosstool") cc_crosstool(name = "crosstool")
# Node dependencies
http_archive(
name = "build_bazel_rules_nodejs",
sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
)
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
build_bazel_rules_nodejs_dependencies()
# fetches nodejs, npm, and yarn
load("@build_bazel_rules_nodejs//:index.bzl", "node_repositories", "yarn_install")
node_repositories()
yarn_install(
name = "npm",
package_json = "//:package.json",
yarn_lock = "//:yarn.lock",
)
# Protobuf for Node dependencies
http_archive(
name = "rules_proto_grpc",
sha256 = "bbe4db93499f5c9414926e46f9e35016999a4e9f6e3522482d3760dc61011070",
strip_prefix = "rules_proto_grpc-4.2.0",
urls = ["https://github.com/rules-proto-grpc/rules_proto_grpc/archive/4.2.0.tar.gz"],
)
http_archive(
name = "com_google_protobuf_javascript",
sha256 = "35bca1729532b0a77280bf28ab5937438e3dcccd6b31a282d9ae84c896b6f6e3",
strip_prefix = "protobuf-javascript-3.21.2",
urls = ["https://github.com/protocolbuffers/protobuf-javascript/archive/refs/tags/v3.21.2.tar.gz"],
)
load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains", "rules_proto_grpc_repos")
rules_proto_grpc_toolchains()
rules_proto_grpc_repos()
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
rules_proto_dependencies()
rules_proto_toolchains()
load("//third_party:external_files.bzl", "external_files") load("//third_party:external_files.bzl", "external_files")
external_files() external_files()

View File

@ -141,3 +141,4 @@ Nvidia Jetson and Raspberry Pi, please read
```bash ```bash
(mp_env)mediapipe$ python3 setup.py bdist_wheel (mp_env)mediapipe$ python3 setup.py bdist_wheel
``` ```
7. Exit from the MediaPipe repo directory and launch the Python interpreter.

View File

@ -25,12 +25,12 @@
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h" #include "mediapipe/util/tflite/tflite_gpu_runner.h"
#if defined(MEDIAPIPE_ANDROID) #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/filesystem.h" #include "mediapipe/util/android/file/base/filesystem.h"
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#endif // MEDIAPIPE_ANDROID #endif // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
@ -231,7 +231,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
return tflite_gpu_runner_->Build(); return tflite_gpu_runner_->Build();
} }
#if defined(MEDIAPIPE_ANDROID) #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
const mediapipe::InferenceCalculatorOptions& options, const mediapipe::InferenceCalculatorOptions& options,
const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
@ -318,7 +318,7 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
tflite::gpu::TFLiteGPURunner* gpu_runner) const { tflite::gpu::TFLiteGPURunner* gpu_runner) const {
return absl::OkStatus(); return absl::OkStatus();
} }
#endif // MEDIAPIPE_ANDROID #endif // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract( absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
CalculatorContract* cc) { CalculatorContract* cc) {

View File

@ -15,7 +15,7 @@
# Description: # Description:
# The dependencies of mediapipe. # The dependencies of mediapipe.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
licenses(["notice"]) licenses(["notice"])
@ -38,7 +38,7 @@ bzl_library(
visibility = ["//mediapipe/framework:__subpackages__"], visibility = ["//mediapipe/framework:__subpackages__"],
) )
proto_library( mediapipe_proto_library(
name = "proto_descriptor_proto", name = "proto_descriptor_proto",
srcs = ["proto_descriptor.proto"], srcs = ["proto_descriptor.proto"],
visibility = [ visibility = [
@ -47,13 +47,6 @@ proto_library(
], ],
) )
mediapipe_cc_proto_library(
name = "proto_descriptor_cc_proto",
srcs = ["proto_descriptor.proto"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [":proto_descriptor_proto"],
)
cc_library( cc_library(
name = "aligned_malloc_and_free", name = "aligned_malloc_and_free",
hdrs = ["aligned_malloc_and_free.h"], hdrs = ["aligned_malloc_and_free.h"],

View File

@ -4,6 +4,9 @@
""".bzl file for mediapipe open source build configs.""" """.bzl file for mediapipe open source build configs."""
load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library")
load("@npm//@bazel/typescript:index.bzl", "ts_project")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_proto_grpc//js:defs.bzl", "js_proto_library")
load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library") load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library")
def provided_args(**kwargs): def provided_args(**kwargs):
@ -71,7 +74,7 @@ def mediapipe_proto_library(
def_jspb_proto: define the jspb_proto_library target def_jspb_proto: define the jspb_proto_library target
def_options_lib: define the mediapipe_options_library target def_options_lib: define the mediapipe_options_library target
""" """
_ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps] _ignore = [def_portable_proto, def_objc_proto, def_java_proto, portable_deps] # buildifier: disable=unused-variable
# The proto_library targets for the compiled ".proto" source files. # The proto_library targets for the compiled ".proto" source files.
proto_deps = [":" + name] proto_deps = [":" + name]
@ -119,6 +122,24 @@ def mediapipe_proto_library(
compatible_with = compatible_with, compatible_with = compatible_with,
)) ))
if def_jspb_proto:
js_deps = replace_deps(deps, "_proto", "_jspb_proto", False)
proto_library(
name = replace_suffix(name, "_proto", "_lib_proto"),
srcs = srcs,
deps = deps,
)
js_proto_library(
name = replace_suffix(name, "_proto", "_jspb_proto"),
protos = [replace_suffix(name, "_proto", "_lib_proto")],
output_mode = "NO_PREFIX_FLAT",
# Need to specify this to work around bug in js_proto_library()
# https://github.com/bazelbuild/rules_nodejs/issues/3503
legacy_path = "unused",
deps = js_deps,
visibility = visibility,
)
if def_options_lib: if def_options_lib:
cc_deps = replace_deps(deps, "_proto", "_cc_proto") cc_deps = replace_deps(deps, "_proto", "_cc_proto")
mediapipe_options_library(**provided_args( mediapipe_options_library(**provided_args(
@ -182,3 +203,35 @@ def mediapipe_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps
default_runtime = "@com_google_protobuf//:protobuf", default_runtime = "@com_google_protobuf//:protobuf",
alwayslink = 1, alwayslink = 1,
)) ))
def mediapipe_ts_library(
name,
srcs,
visibility = None,
deps = [],
testonly = 0,
allow_unoptimized_namespaces = False):
"""Generate ts_project for MediaPipe open source version.
Args:
name: the name of the cc_proto_library.
srcs: the .proto files of the cc_proto_library for Bazel use.
visibility: visibility of this target.
deps: a list of dependency labels for Bazel use; must be cc_proto_library.
testonly: test only or not.
allow_unoptimized_namespaces: ignored, used only internally
"""
_ignore = [allow_unoptimized_namespaces] # buildifier: disable=unused-variable
ts_project(**provided_args(
name = name,
srcs = srcs,
visibility = visibility,
deps = deps + [
"@npm//@types/offscreencanvas",
"@npm//@types/google-protobuf",
],
testonly = testonly,
declaration = True,
tsconfig = "//:tsconfig.json",
))

View File

@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
"""Prints a summary of the model.""" """Prints a summary of the model."""
self._model.summary() self._model.summary()
# TODO: Remove this method when all tasks use Metadata writer
def export_tflite( def export_tflite(
self, self,
export_dir: str, export_dir: str,
@ -62,7 +63,7 @@ class CustomModel(abc.ABC):
Args: Args:
export_dir: The directory to save exported files. export_dir: The directory to save exported files.
tflite_filename: File name to save tflite model. The full export path is tflite_filename: File name to save TFLite model. The full export path is
{export_dir}/{tflite_filename}. {export_dir}/{tflite_filename}.
quantization_config: The configuration for model quantization. quantization_config: The configuration for model quantization.
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
tf.io.gfile.makedirs(export_dir) tf.io.gfile.makedirs(export_dir)
tflite_filepath = os.path.join(export_dir, tflite_filename) tflite_filepath = os.path.join(export_dir, tflite_filename)
# TODO: Populate metadata to the exported TFLite model. tflite_model = model_util.convert_to_tflite(
model_util.export_tflite(
model=self._model, model=self._model,
tflite_filepath=tflite_filepath,
quantization_config=quantization_config, quantization_config=quantization_config,
preprocess=preprocess) preprocess=preprocess)
model_util.save_tflite(
tflite_model=tflite_model, tflite_file=tflite_filepath)
tf.compat.v1.logging.info( tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully: %s' % tflite_filepath) 'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

View File

@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
return len(train_data) // batch_size return len(train_data) // batch_size
def export_tflite( def convert_to_tflite(
model: tf.keras.Model, model: tf.keras.Model,
tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None, quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet, supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None): preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
"""Converts the model to tflite format and saves it. """Converts the input Keras model to TFLite format.
Args: Args:
model: model to be converted to tflite. model: Keras model to be converted to TFLite.
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization. quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file. supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label, quantization. The callable takes three arguments in order: feature, label,
and is_training. and is_training.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")
Returns:
bytearray of TFLite model
"""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, 'saved_model') save_path = os.path.join(temp_dir, 'saved_model')
model.save(save_path, include_optimizer=False, save_format='tf') model.save(save_path, include_optimizer=False, save_format='tf')
@ -122,9 +119,22 @@ def export_tflite(
converter.target_spec.supported_ops = supported_ops converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert() tflite_model = converter.convert()
return tflite_model
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
"""Saves TFLite file to tflite_file.
Args:
tflite_model: A valid flatbuffer representing the TFLite model.
tflite_file: File path to save TFLite model.
"""
if tflite_file is None:
raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
f.write(tflite_model) f.write(tflite_model)
tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully to: %s' % tflite_file)
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
class LiteRunner(object): class LiteRunner(object):
"""A runner to do inference with the TFLite model.""" """A runner to do inference with the TFLite model."""
def __init__(self, tflite_filepath: str): def __init__(self, tflite_model: bytearray):
"""Initializes Lite runner with tflite model file. """Initializes Lite runner from TFLite model buffer.
Args: Args:
tflite_filepath: File path to the TFLite model. tflite_model: A valid flatbuffer representing the TFLite model.
""" """
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
tflite_model = f.read()
self.interpreter = tf.lite.Interpreter(model_content=tflite_model) self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.interpreter.allocate_tensors() self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details() self.input_details = self.interpreter.get_input_details()
@ -250,9 +258,9 @@ class LiteRunner(object):
return output_tensors return output_tensors
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner': def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
"""Returns a `LiteRunner` from file path to TFLite model.""" """Returns a `LiteRunner` from flatbuffer of the TFLite model."""
lite_runner = LiteRunner(tflite_filepath) lite_runner = LiteRunner(tflite_buffer)
return lite_runner return lite_runner

View File

@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
'name': 'test' 'name': 'test'
}) })
def test_export_tflite(self): def test_convert_to_tflite(self):
input_dim = 4 input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') tflite_model = model_util.convert_to_tflite(model)
model_util.export_tflite(model, tflite_file)
test_util.test_tflite( test_util.test_tflite(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) keras_model=model, tflite_model=tflite_model, size=[1, input_dim])
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
testcase_name='float16_quantize', testcase_name='float16_quantize',
config=quantization.QuantizationConfig.for_float16(), config=quantization.QuantizationConfig.for_float16(),
model_size=1468)) model_size=1468))
def test_export_tflite_quantized(self, config, model_size): def test_convert_to_tflite_quantized(self, config, model_size):
input_dim = 16 input_dim = 16
num_classes = 2 num_classes = 2
max_input_value = 5 max_input_value = 5
model = test_util.build_model( model = test_util.build_model(
input_shape=[input_dim], num_classes=num_classes) input_shape=[input_dim], num_classes=num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite( tflite_model = model_util.convert_to_tflite(
model=model, tflite_filepath=tflite_file, quantization_config=config) model=model, quantization_config=config)
self.assertTrue( self.assertTrue(
test_util.test_tflite( test_util.test_tflite(
keras_model=model, keras_model=model,
tflite_file=tflite_file, tflite_model=tflite_model,
size=[1, input_dim], size=[1, input_dim],
high=max_input_value, high=max_input_value,
atol=1e-00)) atol=1e-00))
self.assertNear(os.path.getsize(tflite_file), model_size, 300) self.assertNear(len(tflite_model), model_size, 300)
def test_save_tflite(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
test_util.test_tflite_file(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()

View File

@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
return model return model
def is_same_output(tflite_file: str, def is_same_output(tflite_model: bytearray,
keras_model: tf.keras.Model, keras_model: tf.keras.Model,
input_tensors: Union[List[tf.Tensor], tf.Tensor], input_tensors: Union[List[tf.Tensor], tf.Tensor],
atol: float = 1e-04) -> bool: atol: float = 1e-04) -> bool:
"""Returns if the output of TFLite model and keras model are identical.""" """Returns if the output of TFLite model and keras model are identical."""
# Gets output from lite model. # Gets output from lite model.
lite_runner = model_util.get_lite_runner(tflite_file) lite_runner = model_util.get_lite_runner(tflite_model)
lite_output = lite_runner.run(input_tensors) lite_output = lite_runner.run(input_tensors)
# Gets output from keras model. # Gets output from keras model.
@ -95,12 +95,41 @@ def is_same_output(tflite_file: str,
def test_tflite(keras_model: tf.keras.Model, def test_tflite(keras_model: tf.keras.Model,
tflite_file: str, tflite_model: bytearray,
size: Union[int, List[int]], size: Union[int, List[int]],
high: float = 1, high: float = 1,
atol: float = 1e-04) -> bool: atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical. """Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_model: Input TFLite model flatbuffer.
size: Size of the input tesnor.
high: Higher boundary of the values in input tensors.
atol: Absolute tolerance of the difference between the outputs of Keras
model and TFLite model.
Returns:
True if the output of TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_model=tflite_model,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)
def test_tflite_file(keras_model: tf.keras.Model,
tflite_file: bytearray,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args: Args:
keras_model: Input TensorFlow Keras model. keras_model: Input TensorFlow Keras model.
tflite_file: Input TFLite model file. tflite_file: Input TFLite model file.
@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
True if the output of TFLite model and TF Keras model are identical. True if the output of TFLite model and TF Keras model are identical.
Otherwise, False. Otherwise, False.
""" """
random_input = create_random_sample(size=size, high=high) with tf.io.gfile.GFile(tflite_file, "rb") as f:
random_input = tf.convert_to_tensor(random_input) tflite_model = f.read()
return test_tflite(keras_model, tflite_model, size, high, atol)
return is_same_output(
tflite_file=tflite_file,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)

View File

@ -1,4 +0,0 @@
# MediaPipe Model Maker Internal Library
This directory contains model maker library for internal users and experimental
purposes.

View File

@ -1 +0,0 @@
"""Model maker internal library."""

View File

@ -81,6 +81,8 @@ py_library(
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/model_maker/python/vision/core:image_preprocessing", "//mediapipe/model_maker/python/vision/core:image_preprocessing",
"//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
], ],
) )
@ -88,7 +90,11 @@ py_library(
name = "image_classifier_test_lib", name = "image_classifier_test_lib",
testonly = 1, testonly = 1,
srcs = ["image_classifier_test.py"], srcs = ["image_classifier_test.py"],
deps = [":image_classifier_import"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
deps = [
":image_classifier_import",
"//mediapipe/tasks/python/test:test_utils",
],
) )
py_test( py_test(

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""APIs to train image classifier model.""" """APIs to train image classifier model."""
import os
from typing import List, Optional from typing import List, Optional
@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
class ImageClassifier(classifier.Classifier): class ImageClassifier(classifier.Classifier):
@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
self, self,
model_name: str = 'model.tflite', model_name: str = 'model.tflite',
quantization_config: Optional[quantization.QuantizationConfig] = None): quantization_config: Optional[quantization.QuantizationConfig] = None):
"""Converts the model to the requested formats and exports to a file. """Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also
saves a metadata.json file to the same directory as the TFLite file which
can be used to interpret the metadata content in the TFLite file.
Args: Args:
model_name: File name to save tflite model. The full export path is model_name: File name to save TFLite model with metadata. The full export
{export_dir}/{tflite_filename}. path is {self._hparams.model_dir}/{model_name}.
quantization_config: The configuration for model quantization. quantization_config: The configuration for model quantization.
""" """
super().export_tflite( if not tf.io.gfile.exists(self._hparams.model_dir):
self._hparams.model_dir, tf.io.gfile.makedirs(self._hparams.model_dir)
model_name, tflite_file = os.path.join(self._hparams.model_dir, model_name)
quantization_config, metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
tflite_model = model_util.convert_to_tflite(
model=self._model,
quantization_config=quantization_config,
preprocess=self._preprocess) preprocess=self._preprocess)
writer = image_classifier_writer.MetadataWriter.create(
tflite_model,
self._model_spec.mean_rgb,
self._model_spec.stddev_rgb,
labels=metadata_writer.Labels().add(self._label_names))
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, 'w') as f:
f.write(metadata_json)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import filecmp
import os import os
from absl.testing import parameterized from absl.testing import parameterized
@ -19,6 +20,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.tasks.python.test import test_utils
def _fill_image(rgb, image_size): def _fill_image(rgb, image_size):
@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self.test_data) validation_data=self.test_data)
self._test_accuracy(model) self._test_accuracy(model)
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self): def test_efficientnetlite0_model_train_and_export(self):
hparams = image_classifier.HParams( hparams = image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True) train_epochs=1, batch_size=1, shuffle=True)
model = image_classifier.ImageClassifier.create( model = image_classifier.ImageClassifier.create(
@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self.test_data) validation_data=self.test_data)
self._test_accuracy(model) self._test_accuracy(model)
# Test export_model
model.export_model()
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
self.assertTrue(os.path.exists(output_tflite_file))
self.assertGreater(os.path.getsize(output_tflite_file), 0)
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
def _test_accuracy(self, model, threshold=0.0): def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self.test_data) _, accuracy = model.evaluate(self.test_data)
self.assertGreaterEqual(accuracy, threshold) self.assertGreaterEqual(accuracy, threshold)

View File

@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["metadata.json"],
)

View File

@ -0,0 +1,68 @@
{
"name": "ImageClassifier",
"description": "Identify the most prominent object in the image from a known set of categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0
],
"std": [
255.0
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}

View File

@ -65,6 +65,7 @@ cc_library(
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode", "//mediapipe/tasks/cc/audio/core:running_mode",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

View File

@ -18,12 +18,14 @@ limitations under the License.
#include <map> #include <map>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@ -38,12 +40,16 @@ namespace audio_classifier {
namespace { namespace {
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO"; constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultStreamName[] = "classification_result_out"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsName[] = "classifications_out";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications_out";
constexpr char kSampleRateName[] = "sample_rate_in"; constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig(
} }
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap( subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
options_proto.get()); options_proto.get());
subgraph.Out(kClassificationResultTag) subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
.SetName(kClassificationResultStreamName) >> graph.Out(kClassificationsTag);
graph.Out(kClassificationResultTag); subgraph.Out(kTimestampedClassificationsTag)
.SetName(kTimestampedClassificationsName) >>
graph.Out(kTimestampedClassificationsTag);
return graph.GetConfig(); return graph.GetConfig();
} }
@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
return options_proto; return options_proto;
} }
absl::StatusOr<ClassificationResult> ConvertOutputPackets( absl::StatusOr<std::vector<AudioClassifierResult>> ConvertOutputPackets(
absl::StatusOr<tasks::core::PacketMap> status_or_packets) { absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) { if (!status_or_packets.ok()) {
return status_or_packets.status(); return status_or_packets.status();
} }
return status_or_packets.value()[kClassificationResultStreamName] auto classification_results =
.Get<ClassificationResult>(); status_or_packets.value()[kTimestampedClassificationsName]
.Get<std::vector<ClassificationResult>>();
std::vector<AudioClassifierResult> results;
results.reserve(classification_results.size());
for (const auto& classification_result : classification_results) {
results.emplace_back(ConvertToClassificationResult(classification_result));
}
return results;
}
absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) {
return status_or_packets.status();
}
return ConvertToClassificationResult(
status_or_packets.value()[kClassificationsName]
.Get<ClassificationResult>());
} }
} // namespace } // namespace
@ -118,7 +143,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
auto result_callback = options->result_callback; auto result_callback = options->result_callback;
packets_callback = packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) { [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
result_callback(ConvertOutputPackets(status_or_packets)); result_callback(ConvertAsyncOutputPackets(status_or_packets));
}; };
} }
return core::AudioTaskApiFactory::Create<AudioClassifier, return core::AudioTaskApiFactory::Create<AudioClassifier,
@ -128,7 +153,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<ClassificationResult> AudioClassifier::Classify( absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
Matrix audio_clip, double audio_sample_rate) { Matrix audio_clip, double audio_sample_rate) {
return ConvertOutputPackets(ProcessAudioClip( return ConvertOutputPackets(ProcessAudioClip(
{{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))}, {{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},

View File

@ -18,12 +18,13 @@ limitations under the License.
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
@ -32,6 +33,10 @@ namespace tasks {
namespace audio { namespace audio {
namespace audio_classifier { namespace audio_classifier {
// Alias the shared ClassificationResult struct as result type.
using AudioClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a mediapipe audio classifier task. // The options for configuring a mediapipe audio classifier task.
struct AudioClassifierOptions { struct AudioClassifierOptions {
// Base options for configuring Task library, such as specifying the TfLite // Base options for configuring Task library, such as specifying the TfLite
@ -59,9 +64,8 @@ struct AudioClassifierOptions {
// The user-defined result callback for processing audio stream data. // The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM. // to RunningMode::AUDIO_STREAM.
std::function<void( std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
absl::StatusOr<components::containers::proto::ClassificationResult>)> nullptr;
result_callback = nullptr;
}; };
// Performs audio classification on audio clips or audio stream. // Performs audio classification on audio clips or audio stream.
@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// required to provide the corresponding audio sample rate along with the // required to provide the corresponding audio sample rate along with the
// input audio clips. // input audio clips.
// //
// For each audio clip, the output classifications are grouped in a // The input audio clip may be longer than what the model is able to process
// ClassificationResult object that has three dimensions: // in a single inference. When this occurs, the input audio clip is split into
// Classification head: // multiple chunks starting at different timestamps. For this reason, this
// The prediction heads targeting different audio classification tasks // function returns a vector of ClassificationResult objects, each associated
// such as audio event classification and bird sound classification. // with a timestamp corresponding to the start (in milliseconds) of the chunk
// Classification timestamp: // data that was classified, e.g:
// The start time (in milliseconds) of each audio clip that is sent to the //
// model for audio classification. As the audio classification models take // ClassificationResult #0 (first chunk of data):
// a fixed number of audio samples, long audio clips will be framed to // timestamp_ms: 0 (starts at 0ms)
// multiple buffers (with the desired number of audio samples) during // classifications #0 (single head model):
// preprocessing. // category #0:
// Classification category: // category_name: "Speech"
// The list of the classification categories that model predicts per // score: 0.6
// framed audio clip. // category #1:
// category_name: "Music"
// score: 0.2
// ClassificationResult #1 (second chunk of data):
// timestamp_ms: 800 (starts at 800ms)
// classifications #0 (single head model):
// category #0:
// category_name: "Speech"
// score: 0.5
// category #1:
// category_name: "Silence"
// score: 0.1
// ...
//
// TODO: Use `sample_rate` in AudioClassifierOptions by default // TODO: Use `sample_rate` in AudioClassifierOptions by default
// and makes `audio_sample_rate` optional. // and makes `audio_sample_rate` optional.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify( absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
mediapipe::Matrix audio_clip, double audio_sample_rate); mediapipe::Matrix audio_clip, double audio_sample_rate);
// Sends audio data (a block in a continuous audio stream) to perform audio // Sends audio data (a block in a continuous audio stream) to perform audio
@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// milliseconds) to indicate the start time of the input audio block. The // milliseconds) to indicate the start time of the input audio block. The
// timestamps must be monotonically increasing. // timestamps must be monotonically increasing.
// //
// The output classifications are grouped in a ClassificationResult object // The input audio block may be longer than what the model is able to process
// that has three dimensions: // in a single inference. When this occurs, the input audio block is split
// Classification head: // into multiple chunks. For this reason, the callback may be called multiple
// The prediction heads targeting different audio classification tasks // times (once per chunk) for each call to this function.
// such as audio event classification and bird sound classification.
// Classification timestamp :
// The start time (in milliseconds) of the framed audio block that is sent
// to the model for audio classification.
// Classification category:
// The list of the classification categories that model predicts per
// framed audio clip.
absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms); absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
// Shuts down the AudioClassifier when all works are done. // Shuts down the AudioClassifier when all works are done.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
constexpr char kAudioTag[] = "AUDIO"; constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kPacketTag[] = "PACKET"; constexpr char kPacketTag[] = "PACKET";
constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Struct holding the different output streams produced by the audio classifier
// graph.
struct AudioClassifierOutputStreams {
Source<ClassificationResult> classifications;
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
absl::Status SanityCheckOptions( absl::Status SanityCheckOptions(
const proto::AudioClassifierGraphOptions& options) { const proto::AudioClassifierGraphOptions& options) {
if (options.base_options().use_stream_mode() && if (options.base_options().use_stream_mode() &&
@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator(
// series stream header with sample rate info. // series stream header with sample rate info.
// //
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The aggregated classification result object that has 3 dimensions: // The classification results aggregated by head. Only produces results if
// (classification head, classification timestamp, classification category). // the graph if the 'use_stream_mode' option is true.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Only
// produces results if the graph if the 'use_stream_mode' option is false.
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph" // calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
// input_stream: "AUDIO:audio_in" // input_stream: "AUDIO:audio_in"
// input_stream: "SAMPLE_RATE:sample_rate_in" // input_stream: "SAMPLE_RATE:sample_rate_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATIONS:classifications"
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options { // options {
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext] // [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
// { // {
@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
.base_options() .base_options()
.use_stream_mode(); .use_stream_mode();
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto classification_result_out, auto output_streams,
BuildAudioClassificationTask( BuildAudioClassificationTask(
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources, sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)], graph[Input<Matrix>(kAudioTag)],
@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
? absl::nullopt ? absl::nullopt
: absl::make_optional(graph[Input<double>(kSampleRateTag)]), : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
graph)); graph));
classification_result_out >> output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationResultTag)]; graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// audio_in: (mediapipe::Matrix) stream to run audio classification on. // audio_in: (mediapipe::Matrix) stream to run audio classification on.
// sample_rate_in: (double) optional stream of the input audio sample rate. // sample_rate_in: (double) optional stream of the input audio sample rate.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask( absl::StatusOr<AudioClassifierOutputStreams> BuildAudioClassificationTask(
const proto::AudioClassifierGraphOptions& task_options, const proto::AudioClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in, const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) { absl::optional<Source<double>> sample_rate_in, Graph& graph) {
@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio classification on // Time aggregation is only needed for performing audio classification on
// audio files. Disables time aggregration by not connecting the // audio files. Disables timestamp aggregation by not connecting the
// "TIMESTAMPS" streams. // "TIMESTAMPS" streams.
if (!use_stream_mode) { if (!use_stream_mode) {
audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag); audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
} }
// Outputs the aggregated classification result as the subgraph output // Output both streams as graph output streams/
// stream. return AudioClassifierOutputStreams{
return postprocessing[Output<ClassificationResult>( /*classifications=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)]; kClassificationsTag)],
/*timestamped_classifications=*/
postprocessing[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)],
};
} }
}; };

View File

@ -32,13 +32,11 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe { namespace mediapipe {
@ -49,7 +47,6 @@ namespace {
using ::absl::StatusOr; using ::absl::StatusOr;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) {
return matrix_mapping.matrix(); return matrix_mapping.matrix();
} }
void CheckSpeechClassificationResult(const ClassificationResult& result) { void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
EXPECT_THAT(result.classifications_size(), testing::Eq(1)); int expected_num_categories = 521) {
EXPECT_EQ(result.classifications(0).head_name(), "scores"); EXPECT_EQ(result.size(), 5);
EXPECT_EQ(result.classifications(0).head_index(), 0); // Ignore last result, which operates on a too small chunk to return relevant
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5)); // results.
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925}; std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
for (int i = 0; i < timestamps_ms.size(); i++) { for (int i = 0; i < timestamps_ms.size(); i++) {
EXPECT_THAT(result.classifications(0).entries(0).categories_size(), EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
testing::Eq(521)); EXPECT_EQ(result[i].classifications.size(), 1);
const auto* top_category = auto classifications = result[i].classifications[0];
&result.classifications(0).entries(0).categories(0); EXPECT_EQ(classifications.head_index, 0);
EXPECT_THAT(top_category->category_name(), testing::Eq("Speech")); EXPECT_EQ(classifications.head_name, "scores");
EXPECT_GT(top_category->score(), 0.9f); EXPECT_EQ(classifications.categories.size(), expected_num_categories);
EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(), auto category = classifications.categories[0];
timestamps_ms[i]); EXPECT_EQ(category.index, 0);
EXPECT_EQ(category.category_name, "Speech");
EXPECT_GT(category.score, 0.9f);
} }
} }
void CheckTwoHeadsClassificationResult(const ClassificationResult& result) { void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
EXPECT_THAT(result.classifications_size(), testing::Eq(2)); EXPECT_GE(result.size(), 1);
// Checks classification head #1. EXPECT_LE(result.size(), 2);
EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification"); // Check first result.
EXPECT_EQ(result.classifications(0).head_index(), 0); EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_THAT(result.classifications(0).entries(0).categories_size(), EXPECT_EQ(result[0].classifications.size(), 2);
testing::Eq(521)); // Check first head.
const auto* top_category = EXPECT_EQ(result[0].classifications[0].head_index, 0);
&result.classifications(0).entries(0).categories(0); EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
EXPECT_THAT(top_category->category_name(), EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
testing::Eq("Environmental noise")); EXPECT_EQ(result[0].classifications[0].categories[0].index, 508);
EXPECT_GT(top_category->score(), 0.5f); EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0); "Environmental noise");
if (result.classifications(0).entries_size() == 2) { EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
top_category = &result.classifications(0).entries(1).categories(0); // Check second head.
EXPECT_THAT(top_category->category_name(), testing::Eq("Silence")); EXPECT_EQ(result[0].classifications[1].head_index, 1);
EXPECT_GT(top_category->score(), 0.99f); EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975); EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
"Chestnut-crowned Antpitta");
EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
// Check second result, if present.
if (result.size() == 2) {
EXPECT_EQ(result[1].timestamp_ms, 975);
EXPECT_EQ(result[1].classifications.size(), 2);
// Check first head.
EXPECT_EQ(result[1].classifications[0].head_index, 0);
EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
EXPECT_EQ(result[1].classifications[0].categories[0].index, 494);
EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
"Silence");
EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
// Check second head.
EXPECT_EQ(result[1].classifications[1].head_index, 1);
EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
EXPECT_EQ(result[1].classifications[1].categories[0].index, 1);
EXPECT_EQ(result[1].classifications[1].categories[0].category_name,
"White-breasted Wood-Wren");
EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f);
} }
// Checks classification head #2.
EXPECT_EQ(result.classifications(1).head_name(), "bird_classification");
EXPECT_EQ(result.classifications(1).head_index(), 1);
EXPECT_THAT(result.classifications(1).entries(0).categories_size(),
testing::Eq(5));
top_category = &result.classifications(1).entries(0).categories(0);
EXPECT_THAT(top_category->category_name(),
testing::Eq("Chestnut-crowned Antpitta"));
EXPECT_GT(top_category->score(), 0.9f);
EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0);
} }
ClassificationResult GenerateSpeechClassificationResult() { void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
return ParseTextProtoOrDie<ClassificationResult>( EXPECT_EQ(outputs.size(), 5);
R"pb(classifications { // Ignore last result, which operates on a too small chunk to return relevant
head_index: 0 // results.
head_name: "scores" for (int i = 0; i < outputs.size() - 1; i++) {
entries { EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
categories { index: 0 score: 0.94140625 category_name: "Speech" } EXPECT_EQ(outputs[i].classifications.size(), 1);
timestamp_ms: 0 EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
} EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
entries { EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1);
categories { index: 0 score: 0.9921875 category_name: "Speech" } EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0);
timestamp_ms: 975 EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name,
} "Speech");
entries { EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f);
categories { index: 0 score: 0.98828125 category_name: "Speech" }
timestamp_ms: 1950
}
entries {
categories { index: 0 score: 0.99609375 category_name: "Speech" }
timestamp_ms: 2925
}
entries {
# categories are filtered out due to the low scores.
timestamp_ms: 3900
}
})pb");
}
void CheckStreamingModeClassificationResult(
std::vector<ClassificationResult> outputs) {
ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6);
auto expected_results = GenerateSpeechClassificationResult();
for (int i = 0; i < outputs.size() - 1; ++i) {
EXPECT_THAT(outputs[i].classifications(0).entries(0),
EqualsProto(expected_results.classifications(0).entries(i)));
} }
int last_elem_index = outputs.size() - 1;
EXPECT_EQ(
mediapipe::Timestamp::Done().Value() / 1000,
outputs[last_elem_index].classifications(0).entries(0).timestamp_ms());
} }
class CreateFromOptionsTest : public tflite_shims::testing::Test {}; class CreateFromOptionsTest : public tflite_shims::testing::Test {};
@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->result_callback = options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {}; [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options)); AudioClassifier::Create(std::move(options));
@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback = options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {}; [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options)); AudioClassifier::Create(std::move(options));
@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result); CheckSpeechResult(result);
} }
TEST_F(ClassifyTest, SucceedsWithResampling) { TEST_F(ClassifyTest, SucceedsWithResampling) {
@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result); CheckSpeechResult(result);
} }
TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
auto result_16k_hz, auto result_16k_hz,
audio_classifier->Classify(std::move(audio_buffer_16k_hz), audio_classifier->Classify(std::move(audio_buffer_16k_hz),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
CheckSpeechClassificationResult(result_16k_hz); CheckSpeechResult(result_16k_hz);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result_48k_hz, auto result_48k_hz,
audio_classifier->Classify(std::move(audio_buffer_48k_hz), audio_classifier->Classify(std::move(audio_buffer_48k_hz),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result_48k_hz); CheckSpeechResult(result_48k_hz);
} }
TEST_F(ClassifyTest, SucceedsWithInsufficientData) { TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_classifier->Classify(std::move(zero_matrix), 16000)); auto result, audio_classifier->Classify(std::move(zero_matrix), 16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result.classifications_size(), testing::Eq(1)); EXPECT_EQ(result.size(), 1);
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1)); EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_THAT(result.classifications(0).entries(0).categories_size(), EXPECT_EQ(result[0].classifications.size(), 1);
testing::Eq(521)); EXPECT_EQ(result[0].classifications[0].head_index, 0);
EXPECT_THAT( EXPECT_EQ(result[0].classifications[0].head_name, "scores");
result.classifications(0).entries(0).categories(0).category_name(), EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
testing::Eq("Silence")); EXPECT_EQ(result[0].classifications[0].categories[0].index, 494);
EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(), EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
testing::FloatEq(.800781f)); "Silence");
EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f);
} }
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result); CheckTwoHeadsResult(result);
} }
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/44100)); /*audio_sample_rate=*/44100));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result); CheckTwoHeadsResult(result);
} }
TEST_F(ClassifyTest, TEST_F(ClassifyTest,
@ -413,13 +402,13 @@ TEST_F(ClassifyTest,
auto result_44k_hz, auto result_44k_hz,
audio_classifier->Classify(std::move(audio_buffer_44k_hz), audio_classifier->Classify(std::move(audio_buffer_44k_hz),
/*audio_sample_rate=*/44100)); /*audio_sample_rate=*/44100));
CheckTwoHeadsClassificationResult(result_44k_hz); CheckTwoHeadsResult(result_44k_hz);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result_16k_hz, auto result_16k_hz,
audio_classifier->Classify(std::move(audio_buffer_16k_hz), audio_classifier->Classify(std::move(audio_buffer_16k_hz),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result_16k_hz); CheckTwoHeadsResult(result_16k_hz);
} }
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.35f;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); CheckSpeechResult(result, /*expected_num_categories=*/1);
} }
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); CheckSpeechResult(result, /*expected_num_categories=*/1);
} }
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); CheckSpeechResult(result, /*expected_num_categories=*/1);
} }
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
// All categroies with the "Speech" label are filtered out. // All categories with the "Speech" label are filtered out.
EXPECT_THAT(result, EqualsProto(R"pb(classifications { std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
head_index: 0 for (int i = 0; i < timestamps_ms.size(); i++) {
head_name: "scores" EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
entries { timestamp_ms: 0 } EXPECT_EQ(result[i].classifications.size(), 1);
entries { timestamp_ms: 975 } auto classifications = result[i].classifications[0];
entries { timestamp_ms: 1950 } EXPECT_EQ(classifications.head_index, 0);
entries { timestamp_ms: 2925 } EXPECT_EQ(classifications.head_name, "scores");
entries { timestamp_ms: 3900 } EXPECT_TRUE(classifications.categories.empty());
})pb")); }
} }
class ClassifyAsyncTest : public tflite_shims::testing::Test {}; class ClassifyAsyncTest : public tflite_shims::testing::Test {};
@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
options->classifier_options.score_threshold = 0.3f; options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz; options->sample_rate = kSampleRateHz;
std::vector<ClassificationResult> outputs; std::vector<AudioClassifierResult> outputs;
options->result_callback = options->result_callback =
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) { [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result); MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
}; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
start_col += kYamnetNumOfAudioSamples * 3; start_col += kYamnetNumOfAudioSamples * 3;
} }
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckStreamingModeClassificationResult(outputs); CheckStreamingModeResults(outputs);
} }
TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
options->classifier_options.score_threshold = 0.3f; options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz; options->sample_rate = kSampleRateHz;
std::vector<ClassificationResult> outputs; std::vector<AudioClassifierResult> outputs;
options->result_callback = options->result_callback =
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) { [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result); MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
}; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
start_col += num_samples; start_col += num_samples;
} }
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckStreamingModeClassificationResult(outputs); CheckStreamingModeResults(outputs);
} }
} // namespace } // namespace

View File

@ -40,6 +40,7 @@ cc_library(
"//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
@ -60,41 +61,6 @@ cc_library(
# TODO: Enable this test # TODO: Enable this test
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
# TODO: Investigate rewriting the build rule to only link # TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed. # the Bert Preprocessor if it's needed.
cc_library( cc_library(

View File

@ -163,7 +163,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
], ],
) )
@ -178,7 +178,7 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],

View File

@ -26,14 +26,14 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::proto::Embedding;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in // Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
class TensorsToEmbeddingsCalculator : public Node { class TensorsToEmbeddingsCalculator : public Node {
public: public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"}; static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"}; static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut); MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
bool quantize_; bool quantize_;
std::vector<std::string> head_names_; std::vector<std::string> head_names_;
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
}; };
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
for (int i = 0; i < tensors.size(); ++i) { for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor = tensors[i]; const auto& tensor = tensors[i];
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
auto* embeddings = result.add_embeddings(); auto* embedding = result.add_embeddings();
embeddings->set_head_index(i); embedding->set_head_index(i);
if (!head_names_.empty()) { if (!head_names_.empty()) {
embeddings->set_head_name(head_names_[i]); embedding->set_head_name(head_names_[i]);
} }
if (quantize_) { if (quantize_) {
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries()); FillQuantizedEmbedding(tensor, embedding);
} else { } else {
FillFloatEmbeddingEntry(tensor, embeddings->add_entries()); FillFloatEmbedding(tensor, embedding);
} }
} }
kEmbeddingsOut(cc).Send(result); kEmbeddingsOut(cc).Send(result);
return absl::OkStatus(); return absl::OkStatus();
} }
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry( void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
const Tensor& tensor, EmbeddingEntry* entry) { Embedding* embedding) {
int size = tensor.shape().num_elements(); int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView(); auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>(); const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm = float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* float_embedding = entry->mutable_float_embedding(); auto* float_embedding = embedding->mutable_float_embedding();
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm); float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
} }
} }
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry( void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
const Tensor& tensor, EmbeddingEntry* entry) { const Tensor& tensor, Embedding* embedding) {
int size = tensor.shape().num_elements(); int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView(); auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>(); const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm = float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* values = entry->mutable_quantized_embedding()->mutable_values(); auto* values = embedding->mutable_quantized_embedding()->mutable_values();
values->resize(size); values->resize(size);
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
// Normalize. // Normalize.

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
message TensorsToEmbeddingsCalculatorOptions { message TensorsToEmbeddingsCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {
// The embedder options defining whether to L2-normalize or scalar-quantize // The embedder options defining whether to L2-normalize or scalar-quantize
// the outputs. // the outputs.
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options = optional mediapipe.tasks.components.processors.proto.EmbedderOptions
1; embedder_options = 1;
// The embedder head names. // The embedder head names.
repeated string head_names = 2; repeated string head_names = 2;

View File

@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" } [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
} }
@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false } embedder_options { l2_normalize: false quantize: false }
@ -84,28 +84,24 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0] EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
.Get<EmbeddingResult>(); R"pb(embeddings {
EXPECT_THAT( float_embedding { values: 0.1 values: 0.2 }
result, head_index: 0
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( }
R"pb(embeddings { embeddings {
entries { float_embedding { values: 0.1 values: 0.2 } } float_embedding { values: -0.2 values: -0.3 }
head_index: 0 head_index: 1
} })pb")));
embeddings {
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
})pb")));
} }
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false } embedder_options { l2_normalize: false quantize: false }
@ -118,30 +114,26 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0] EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
.Get<EmbeddingResult>(); R"pb(embeddings {
EXPECT_THAT( float_embedding { values: 0.1 values: 0.2 }
result, head_index: 0
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( head_name: "foo"
R"pb(embeddings { }
entries { float_embedding { values: 0.1 values: 0.2 } } embeddings {
head_index: 0 float_embedding { values: -0.2 values: -0.3 }
head_name: "foo" head_index: 1
} head_name: "bar"
embeddings { })pb")));
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
head_name: "bar"
})pb")));
} }
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: false } embedder_options { l2_normalize: true quantize: false }
@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT( EXPECT_THAT(
result, result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { R"pb(embeddings {
entries { float_embedding { values: 0.44721356 values: 0.8944271 }
float_embedding { values: 0.44721356 values: 0.8944271 }
}
head_index: 0 head_index: 0
} }
embeddings { embeddings {
entries { float_embedding { values: -0.5547002 values: -0.8320503 }
float_embedding { values: -0.5547002 values: -0.8320503 }
}
head_index: 1 head_index: 1
})pb"))); })pb")));
} }
@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: true } embedder_options { l2_normalize: false quantize: true }
@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(result, EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { R"pb(embeddings {
entries { quantized_embedding { values: "\x0d\x1a" } # 13,26
quantized_embedding { values: "\x0d\x1a" } # 13,26
}
head_index: 0 head_index: 0
} }
embeddings { embeddings {
entries { quantized_embedding { values: "\xe6\xda" } # -26,-38
quantized_embedding { values: "\xe6\xda" } # -26,-38
}
head_index: 1 head_index: 1
})pb"))); })pb")));
} }
@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: true } embedder_options { l2_normalize: true quantize: true }
@ -224,25 +204,18 @@ TEST(TensorsToEmbeddingsCalculatorTest,
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0] EXPECT_THAT(result,
.Get<EmbeddingResult>(); EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
EXPECT_THAT( R"pb(embeddings {
result, quantized_embedding { values: "\x39\x72" } # 57,114
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( head_index: 0
R"pb(embeddings { }
entries { embeddings {
quantized_embedding { values: "\x39\x72" } # 57,114 quantized_embedding { values: "\xb9\x95" } # -71,-107
} head_index: 1
head_index: 0 })pb")));
}
embeddings {
entries {
quantized_embedding { values: "\xb9\x95" } # -71,-107
}
head_index: 1
})pb")));
} }
} // namespace } // namespace

View File

@ -49,3 +49,12 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
], ],
) )
cc_library(
name = "embedding_result",
srcs = ["embedding_result.cc"],
hdrs = ["embedding_result.h"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
],
)

View File

@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include <iterator>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
Embedding ConvertToEmbedding(const proto::Embedding& proto) {
Embedding embedding;
if (proto.has_float_embedding()) {
embedding.float_embedding = {
std::make_move_iterator(proto.float_embedding().values().begin()),
std::make_move_iterator(proto.float_embedding().values().end())};
} else {
embedding.quantized_embedding = {
std::make_move_iterator(proto.quantized_embedding().values().begin()),
std::make_move_iterator(proto.quantized_embedding().values().end())};
}
embedding.head_index = proto.head_index();
if (proto.has_head_name()) {
embedding.head_name = proto.head_name();
}
return embedding;
}
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
EmbeddingResult embedding_result;
embedding_result.embeddings.reserve(proto.embeddings_size());
for (const auto& embedding : proto.embeddings()) {
embedding_result.embeddings.push_back(ConvertToEmbedding(embedding));
}
if (proto.has_timestamp_ms()) {
embedding_result.timestamp_ms = proto.timestamp_ms();
}
return embedding_result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,72 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
// Embedding result for a given embedder head.
//
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
// contain data, based on whether or not the embedder was configured to perform
// scalar quantization.
struct Embedding {
// Floating-point embedding. Empty if the embedder was configured to perform
// scalar-quantization.
std::vector<float> float_embedding;
// Scalar-quantized embedding. Empty if the embedder was not configured to
// perform scalar quantization.
std::string quantized_embedding;
// The index of the embedder head (i.e. output tensor) this embedding comes
// from. This is useful for multi-head models.
int head_index;
// The optional name of the embedder head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
std::optional<std::string> head_name = std::nullopt;
};
// Defines embedding results of a model.
struct EmbeddingResult {
// The embedding results for each head of the model.
std::vector<Embedding> embeddings;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Embedding proto to Embedding struct.
Embedding ConvertToEmbedding(const proto::Embedding& proto);
// Utility function to convert from EmbeddingResult proto to EmbeddingResult
// struct.
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_

View File

@ -30,30 +30,31 @@ message QuantizedEmbedding {
optional bytes values = 1; optional bytes values = 1;
} }
// Floating-point or scalar-quantized embedding with an optional timestamp. // Embedding result for a given embedder head.
message EmbeddingEntry { message Embedding {
// The actual embedding, either floating-point or scalar-quantized. // The actual embedding, either floating-point or quantized.
oneof embedding { oneof embedding {
FloatEmbedding float_embedding = 1; FloatEmbedding float_embedding = 1;
QuantizedEmbedding quantized_embedding = 2; QuantizedEmbedding quantized_embedding = 2;
} }
// The optional timestamp (in milliseconds) associated to the embedding entry.
// This is useful for time series use cases, e.g. audio embedding.
optional int64 timestamp_ms = 3;
}
// Embeddings for a given embedder head.
message Embeddings {
repeated EmbeddingEntry entries = 1;
// The index of the embedder head that produced this embedding. This is useful // The index of the embedder head that produced this embedding. This is useful
// for multi-head models. // for multi-head models.
optional int32 head_index = 2; optional int32 head_index = 3;
// The name of the embedder head, which is the corresponding tensor metadata // The name of the embedder head, which is the corresponding tensor metadata
// name (if any). This is useful for multi-head models. // name (if any). This is useful for multi-head models.
optional string head_name = 3; optional string head_name = 4;
} }
// Contains one set of results per embedder head. // Embedding results for a given embedder model.
message EmbeddingResult { message EmbeddingResult {
repeated Embeddings embeddings = 1; // The embedding results for each model head, i.e. one for each output tensor.
repeated Embedding embeddings = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
} }

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"

View File

@ -62,3 +62,38 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

View File

@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options) { EmbedderOptions* embedder_options) {
tasks::components::proto::EmbedderOptions options_proto; proto::EmbedderOptions options_proto;
options_proto.set_l2_normalize(embedder_options->l2_normalize); options_proto.set_l2_normalize(embedder_options->l2_normalize);
options_proto.set_quantize(embedder_options->quantize); options_proto.set_quantize(embedder_options->quantize);
return options_proto; return options_proto;
} }
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Embedder options for MediaPipe C++ embedding extraction tasks. // Embedder options for MediaPipe C++ embedding extraction tasks.
struct EmbedderOptions { struct EmbedderOptions {
@ -37,11 +38,12 @@ struct EmbedderOptions {
bool quantize; bool quantize;
}; };
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options); EmbedderOptions* embedder_options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <string> #include <string>
#include <vector> #include <vector>
@ -29,8 +29,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -39,6 +39,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using TensorsSource = using TensorsSource =
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>; ::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
// Identifies whether or not the model has quantized outputs, and performs // Identifies whether or not the model has quantized outputs, and performs
// sanity checks. // sanity checks.
@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const ModelResources& model_resources, const ModelResources& model_resources,
const EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options) { proto::EmbeddingPostprocessingGraphOptions* options) {
ASSIGN_OR_RETURN(bool has_quantized_outputs, ASSIGN_OR_RETURN(bool has_quantized_outputs,
HasQuantizedOutputs(model_resources)); HasQuantizedOutputs(model_resources));
@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
BuildEmbeddingPostprocessing( BuildEmbeddingPostprocessing(
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(), sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph)); graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)]; embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>() .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
.CopyFrom(options.tensors_to_embeddings_options()); .CopyFrom(options.tensors_to_embeddings_options());
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
return tensors_to_embeddings_node[Output<EmbeddingResult>( return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
kEmbeddingResultTag)];
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::EmbeddingPostprocessingGraph); ::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Configures an EmbeddingPostprocessingGraph using the provided model resources // Configures an EmbeddingPostprocessingGraph using the provided model resources
// and EmbedderOptions. // and EmbedderOptions.
@ -44,18 +45,19 @@ namespace components {
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult
// The output EmbeddingResult. // The output EmbeddingResult.
// //
// TODO: add support for additional optional "TIMESTAMPS" input for // TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation. // embeddings aggregation.
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const tasks::components::proto::EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options); proto::EmbeddingPostprocessingGraphOptions* options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <memory> #include <memory>
@ -25,8 +25,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -34,12 +34,10 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -69,68 +67,72 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_l2_normalize(true); options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { l2_normalize: true } R"pb(tensors_to_embeddings_options {
head_names: "probability" embedder_options { l2_normalize: true }
} head_names: "probability"
has_quantized_outputs: true)pb"))); }
has_quantized_outputs: true)pb")));
} }
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_quantize(true); options_in.set_quantize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { quantize: true } R"pb(tensors_to_embeddings_options {
} embedder_options { quantize: true }
has_quantized_outputs: true)pb"))); }
has_quantized_outputs: true)pb")));
} }
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(auto model_resources, MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
CreateModelResourcesForModel(kMobileNetV3Embedder)); CreateModelResourcesForModel(kMobileNetV3Embedder));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_quantize(true); options_in.set_quantize(true);
options_in.set_l2_normalize(true); options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { quantize: true l2_normalize: true } R"pb(tensors_to_embeddings_options {
head_names: "feature" embedder_options { quantize: true l2_normalize: true }
} head_names: "feature"
has_quantized_outputs: false)pb"))); }
has_quantized_outputs: false)pb")));
} }
// TODO: add E2E Postprocessing tests once timestamp aggregation is // TODO: add E2E Postprocessing tests once timestamp aggregation is
// supported. // supported.
} // namespace } // namespace
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -34,3 +34,18 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
], ],
) )
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)

View File

@ -15,7 +15,10 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "EmbedderOptionsProto";
// Shared options used by all embedding extraction tasks. // Shared options used by all embedding extraction tasks.
message EmbedderOptions { message EmbedderOptions {

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";

View File

@ -23,21 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"], srcs = ["segmenter_options.proto"],
) )
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto", name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"], srcs = ["text_preprocessing_graph_options.proto"],

View File

@ -26,7 +26,7 @@ cc_library(
hdrs = ["cosine_similarity.h"], hdrs = ["cosine_similarity.h"],
deps = [ deps = [
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
@ -39,7 +39,7 @@ cc_test(
deps = [ deps = [
":cosine_similarity", ":cosine_similarity",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers:embedding_result",
], ],
) )

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -30,7 +30,7 @@ namespace utils {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::Embedding;
template <typename T> template <typename T>
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v, absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
// an L2-norm of 0. // an L2-norm of 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u, absl::StatusOr<double> CosineSimilarity(const Embedding& u,
const EmbeddingEntry& v) { const Embedding& v) {
if (u.has_float_embedding() && v.has_float_embedding()) { if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
if (u.float_embedding().values().size() != if (u.float_embedding.size() != v.float_embedding.size()) {
v.float_embedding().values().size()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrFormat("Cannot compute cosine similarity between embeddings " absl::StrFormat("Cannot compute cosine similarity between embeddings "
"of different sizes (%d vs. %d)", "of different sizes (%d vs. %d)",
u.float_embedding().values().size(), u.float_embedding.size(), v.float_embedding.size()),
v.float_embedding().values().size()),
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
return ComputeCosineSimilarity(u.float_embedding().values().data(), return ComputeCosineSimilarity(u.float_embedding.data(),
v.float_embedding().values().data(), v.float_embedding.data(),
u.float_embedding().values().size()); u.float_embedding.size());
} }
if (u.has_quantized_embedding() && v.has_quantized_embedding()) { if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
if (u.quantized_embedding().values().size() != if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
v.quantized_embedding().values().size()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrFormat("Cannot compute cosine similarity between embeddings " absl::StrFormat("Cannot compute cosine similarity between embeddings "
"of different sizes (%d vs. %d)", "of different sizes (%d vs. %d)",
u.quantized_embedding().values().size(), u.quantized_embedding.size(),
v.quantized_embedding().values().size()), v.quantized_embedding.size()),
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>( return ComputeCosineSimilarity(
u.quantized_embedding().values().data()), reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
reinterpret_cast<const int8_t*>( reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
v.quantized_embedding().values().data()), u.quantized_embedding.size());
u.quantized_embedding().values().size());
} }
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,

View File

@ -17,22 +17,20 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace utils { namespace utils {
// Utility function to compute cosine similarity [1] between two embedding // Utility function to compute cosine similarity [1] between two embeddings. May
// entries. May return an InvalidArgumentError if e.g. the feature vectors are // return an InvalidArgumentError if e.g. the embeddings are of different types
// of different types (quantized vs. float), have different sizes, or have a // (quantized vs. float), have different sizes, or have a an L2-norm of 0.
// an L2-norm of 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
absl::StatusOr<double> CosineSimilarity( absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
const containers::proto::EmbeddingEntry& u, const containers::Embedding& v);
const containers::proto::EmbeddingEntry& v);
} // namespace utils } // namespace utils
} // namespace components } // namespace components

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -30,29 +30,27 @@ namespace components {
namespace utils { namespace utils {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::Embedding;
using ::testing::HasSubstr; using ::testing::HasSubstr;
// Helper function to generate float EmbeddingEntry. // Helper function to generate float Embedding.
EmbeddingEntry BuildFloatEntry(std::vector<float> values) { Embedding BuildFloatEmbedding(std::vector<float> values) {
EmbeddingEntry entry; Embedding embedding;
for (const float value : values) { embedding.float_embedding = values;
entry.mutable_float_embedding()->add_values(value); return embedding;
}
return entry;
} }
// Helper function to generate quantized EmbeddingEntry. // Helper function to generate quantized Embedding.
EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) { Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
EmbeddingEntry entry; Embedding embedding;
entry.mutable_quantized_embedding()->set_values( uint8_t* data = reinterpret_cast<uint8_t*>(values.data());
reinterpret_cast<uint8_t*>(values.data()), values.size()); embedding.quantized_embedding = {data, data + values.size()};
return entry; return embedding;
} }
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) { TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildQuantizedEntry({0, 1}); auto v = BuildQuantizedEmbedding({0, 1});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
} }
TEST(CosineSimilarity, FailsWithZeroNorm) { TEST(CosineSimilarity, FailsWithZeroNorm) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEntry({0.0, 0.0}); auto v = BuildFloatEmbedding({0.0, 0.0});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
} }
TEST(CosineSimilarity, FailsWithDifferentSizes) { TEST(CosineSimilarity, FailsWithDifferentSizes) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEntry({0.1, 0.2, 0.3}); auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
} }
TEST(CosineSimilarity, SucceedsWithFloatEntries) { TEST(CosineSimilarity, SucceedsWithFloatEntries) {
auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0}); auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5}); auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
} }
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) { TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
auto u = BuildQuantizedEntry({127, 0, 0, 0}); auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
auto v = BuildQuantizedEntry({-128, 0, 0, 0}); auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));

View File

@ -26,12 +26,12 @@ cc_library(
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
@ -49,9 +49,10 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/components:embedder_options", "//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity", "//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",

View File

@ -21,9 +21,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/tool/options_map.h" #include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" #include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -41,8 +42,8 @@ namespace image_embedder {
namespace { namespace {
constexpr char kEmbeddingResultStreamName[] = "embedding_result_out"; constexpr char kEmbeddingsStreamName[] = "embeddings_out";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto:: using ::mediapipe::tasks::vision::image_embedder::proto::
@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
auto& task_graph = graph.AddNode(kGraphTypeName); auto& task_graph = graph.AddNode(kGraphTypeName);
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get()); task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >> task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
graph.Out(kEmbeddingResultTag); graph.Out(kEmbeddingsTag);
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >> task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator( return tasks::core::AddFlowLimiterCalculator(
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag); graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
} }
graph.In(kImageTag) >> task_graph.In(kImageTag); graph.In(kImageTag) >> task_graph.In(kImageTag);
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag); graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE); options->running_mode != core::RunningMode::IMAGE);
auto embedder_options_proto = auto embedder_options_proto =
std::make_unique<tasks::components::proto::EmbedderOptions>( std::make_unique<components::processors::proto::EmbedderOptions>(
components::ConvertEmbedderOptionsToProto( components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options))); &(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto; return options_proto;
@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
return; return;
} }
Packet embedding_result_packet = Packet embedding_result_packet =
status_or_packets.value()[kEmbeddingResultStreamName]; status_or_packets.value()[kEmbeddingsStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(embedding_result_packet.Get<EmbeddingResult>(), result_callback(ConvertToEmbeddingResult(
embedding_result_packet.Get<EmbeddingResult>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
embedding_result_packet.Timestamp().Value() / embedding_result_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed( absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
Image image, Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo( absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
} }
absl::Status ImageEmbedder::EmbedAsync( absl::Status ImageEmbedder::EmbedAsync(
@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
} }
absl::StatusOr<double> ImageEmbedder::CosineSimilarity( absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
const EmbeddingEntry& u, const EmbeddingEntry& v) { const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v); return components::utils::CosineSimilarity(u, v);
} }

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -33,6 +33,10 @@ namespace tasks {
namespace vision { namespace vision {
namespace image_embedder { namespace image_embedder {
// Alias the shared EmbeddingResult struct as result typo.
using ImageEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
// The options for configuring a MediaPipe image embedder task. // The options for configuring a MediaPipe image embedder task.
struct ImageEmbedderOptions { struct ImageEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model // Base options for configuring MediaPipe Tasks, such as specifying the model
@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
// Options for configuring the embedder behavior, such as L2-normalization or // Options for configuring the embedder behavior, such as L2-normalization or
// scalar-quantization. // scalar-quantization.
components::EmbedderOptions embedder_options; components::processors::EmbedderOptions embedder_options;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void( std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
absl::StatusOr<components::containers::proto::EmbeddingResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// running mode. // running mode.
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed( absl::StatusOr<ImageEmbedderResult> Embed(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo( absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// Shuts down the ImageEmbedder when all works are done. // Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embedding // Utility function to compute cosine similarity [1] between two embeddings.
// entries. May return an InvalidArgumentError if e.g. the feature vectors are // May return an InvalidArgumentError if e.g. the embeddings are of different
// of different types (quantized vs. float), have different sizes, or have a // types (quantized vs. float), have different sizes, or have a an L2-norm of
// an L2-norm of 0. // 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity( static absl::StatusOr<double> CosineSimilarity(
const components::containers::proto::EmbeddingEntry& u, const components::containers::Embedding& u,
const components::containers::proto::EmbeddingEntry& v); const components::containers::Embedding& v);
}; };
} // namespace image_embedder } // namespace image_embedder

View File

@ -20,10 +20,10 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
// Describes region of image to perform embedding extraction on. // Describes region of image to perform embedding extraction on.
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult
// The embedding result. // The embedding result.
// IMAGE - Image // IMAGE - Image
// The image that embedding extraction runs on. // The image that embedding extraction runs on.
@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
// node { // node {
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph" // calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// output_stream: "EMBEDDING_RESULT:embedding_result_out" // output_stream: "EMBEDDINGS:embedding_result_out"
// output_stream: "IMAGE:image_out" // output_stream: "IMAGE:image_out"
// options { // options {
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext] // [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.embedding_result >> output_streams.embedding_result >>
graph[Output<EmbeddingResult>(kEmbeddingResultTag)]; graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)]; output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects its input stream to the // Adds postprocessing calculators and connects its input stream to the
// inference results. // inference results.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.EmbeddingPostprocessingGraph"); "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing( MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(), model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>())); &postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding results. // Outputs the embedding results.
return ImageEmbedderOutputStreams{ return ImageEmbedderOutputStreams{
/*embedding_result=*/postprocessing[Output<EmbeddingResult>( /*embedding_result=*/postprocessing[Output<EmbeddingResult>(
kEmbeddingResultTag)], kEmbeddingsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]}; /*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -42,7 +42,6 @@ namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
// Utility function to check the sizes, head_index and head_names of a result // Utility function to check the sizes, head_index and head_names of a result
// procuded by kMobileNetV3Embedder. // procuded by kMobileNetV3Embedder.
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) { void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
EXPECT_EQ(result.embeddings().size(), 1); EXPECT_EQ(result.embeddings.size(), 1);
EXPECT_EQ(result.embeddings(0).head_index(), 0); EXPECT_EQ(result.embeddings[0].head_index, 0);
EXPECT_EQ(result.embeddings(0).head_name(), "feature"); EXPECT_EQ(result.embeddings[0].head_name, "feature");
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
if (quantized) { if (quantized) {
EXPECT_EQ( EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
result.embeddings(0).entries(0).quantized_embedding().values().size(),
1024);
} else { } else {
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(), EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
1024);
} }
} }
@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = running_mode; options->running_mode = running_mode;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
auto image_embedder = ImageEmbedder::Create(std::move(options)); auto image_embedder = ImageEmbedder::Create(std::move(options));
@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.925519; double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.925519; double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, true); CheckMobileNetV3Result(image_result, true);
CheckMobileNetV3Result(crop_result, true); CheckMobileNetV3Result(crop_result, true);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.926791; double expected_similarity = 0.926791;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& image_result, const ImageEmbedderResult& image_result,
image_embedder->Embed(image, image_processing_options)); image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.999931; double expected_similarity = 0.999931;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
image_processing_options.rotation_degrees = -90; image_processing_options.rotation_degrees = -90;
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result, const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options)); image_embedder->Embed(rotated, image_processing_options));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false); CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), rotated_result.embeddings[0]));
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.572265; double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
/*rotation_degrees=*/-90}; /*rotation_degrees=*/-90};
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result, const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options)); image_embedder->Embed(rotated, image_processing_options));
// Check results. // Check results.
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false); CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, crop_result.embeddings[0],
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), rotated_result.embeddings[0]));
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.62838; double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
EmbeddingResult previous_results; ImageEmbedderResult previous_results;
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results, MP_ASSERT_OK_AND_ASSIGN(auto results,
image_embedder->EmbedForVideo(image, i)); image_embedder->EmbedForVideo(image, i));
CheckMobileNetV3Result(results, false); CheckMobileNetV3Result(results, false);
if (i > 0) { if (i > 0) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, MP_ASSERT_OK_AND_ASSIGN(
ImageEmbedder::CosineSimilarity( double similarity,
results.embeddings(0).entries(0), ImageEmbedder::CosineSimilarity(results.embeddings[0],
previous_results.embeddings(0).entries(0))); previous_results.embeddings[0]));
double expected_similarity = 1.000000; double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
} }
struct LiveStreamModeResults { struct LiveStreamModeResults {
EmbeddingResult embedding_result; ImageEmbedderResult embedding_result;
std::pair<int, int> image_size; std::pair<int, int> image_size;
int64 timestamp_ms; int64 timestamp_ms;
}; };
@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback =
[&results](absl::StatusOr<EmbeddingResult> embedding_result, [&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
const Image& image, int64 timestamp_ms) { const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(embedding_result.status()); MP_ASSERT_OK(embedding_result.status());
results.push_back( results.push_back(
@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
double similarity, double similarity,
ImageEmbedder::CosineSimilarity( ImageEmbedder::CosineSimilarity(
result.embedding_result.embeddings(0).entries(0), result.embedding_result.embeddings[0],
results[i - 1].embedding_result.embeddings(0).entries(0))); results[i - 1].embedding_result.embeddings[0]));
double expected_similarity = 1.000000; double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_embedder.proto; package mediapipe.tasks.vision.image_embedder.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageEmbedderGraphOptions { message ImageEmbedderGraphOptions {
@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
// Options for configuring the embedder behavior, such as normalization or // Options for configuring the embedder behavior, such as normalization or
// quantization. // quantization.
optional components.proto.EmbedderOptions embedder_options = 2; optional components.processors.proto.EmbedderOptions embedder_options = 2;
} }

View File

@ -287,10 +287,9 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i])); tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
} }
} }
return {{ return ImageSegmenterOutputs{
.segmented_masks = segmented_masks, /*segmented_masks=*/segmented_masks,
.image = preprocessing[Output<Image>(kImageTag)], /*image=*/preprocessing[Output<Image>(kImageTag)]};
}};
} }
}; };

View File

@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",

View File

@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
py_library(
name = "audio_data",
srcs = ["audio_data.py"],
)
py_library( py_library(
name = "bounding_box", name = "bounding_box",
srcs = ["bounding_box.py"], srcs = ["bounding_box.py"],

View File

@ -0,0 +1,109 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio data."""
import dataclasses
from typing import Optional
import numpy as np
@dataclasses.dataclass
class AudioFormat:
"""Audio format metadata.
Attributes:
num_channels: the number of channels of the audio data.
sample_rate: the audio sample rate.
"""
num_channels: int = 1
sample_rate: Optional[float] = None
class AudioData(object):
"""MediaPipe Tasks' audio container."""
def __init__(
self, buffer_length: int,
audio_format: AudioFormat = AudioFormat()) -> None:
"""Initializes the `AudioData` object.
Args:
buffer_length: the length of the audio buffer.
audio_format: the audio format metadata.
"""
self._audio_format = audio_format
self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
dtype=np.float32)
def clear(self):
"""Clears the internal buffer and fill it with zeros."""
self._buffer.fill(0)
def load_from_array(self,
src: np.ndarray,
offset: int = 0,
size: int = -1) -> None:
"""Loads the audio data from a NumPy array.
Args:
src: A NumPy source array contains the input audio.
offset: An optional offset for loading a slice of the `src` array to the
buffer.
size: An optional size parameter denoting the number of samples to load
from the `src` array.
Raises:
ValueError: If the input array has an incorrect shape or if
`offset` + `size` exceeds the length of the `src` array.
"""
if src.shape[1] != self._audio_format.num_channels:
raise ValueError(f"Input audio contains an invalid number of channels. "
f"Expect {self._audio_format.num_channels}.")
if size < 0:
size = len(src)
if offset + size > len(src):
raise ValueError(
f"Index out of range. offset {offset} + size {size} should be <= "
f"src's length: {len(src)}")
if len(src) >= len(self._buffer):
# If the internal buffer is shorter than the load target (src), copy
# values from the end of the src array to the internal buffer.
new_offset = offset + size - len(self._buffer)
new_size = len(self._buffer)
self._buffer = src[new_offset:new_offset + new_size].copy()
else:
# Shift the internal buffer backward and add the incoming data to the end
# of the buffer.
shift = size
self._buffer = np.roll(self._buffer, -shift, axis=0)
self._buffer[-shift:, :] = src[offset:offset + size].copy()
@property
def audio_format(self) -> AudioFormat:
"""Gets the audio format of the audio."""
return self._audio_format
@property
def buffer_length(self) -> int:
"""Gets the sample count of the audio."""
return self._buffer.shape[0]
@property
def buffer(self) -> np.ndarray:
"""Gets the internal buffer."""
return self._buffer

View File

@ -24,7 +24,7 @@ py_library(
srcs = ["test_utils.py"], srcs = ["test_utils.py"],
srcs_version = "PY3", srcs_version = "PY3",
visibility = [ visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", "//mediapipe/model_maker/python:__subpackages__",
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",
], ],
deps = [ deps = [

View File

@ -33,8 +33,8 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.OutputType _OutputType = image_segmenter.ImageSegmenterOptions.OutputType
_Activation = image_segmenter.Activation _Activation = image_segmenter.ImageSegmenterOptions.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode

View File

@ -15,17 +15,25 @@
"""MediaPipe Tasks Vision API.""" """MediaPipe Tasks Vision API."""
import mediapipe.tasks.python.vision.core import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.gesture_recognizer
import mediapipe.tasks.python.vision.image_classifier import mediapipe.tasks.python.vision.image_classifier
import mediapipe.tasks.python.vision.image_segmenter
import mediapipe.tasks.python.vision.object_detector import mediapipe.tasks.python.vision.object_detector
GestureRecognizer = gesture_recognizer.GestureRecognizer
GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
ImageClassifier = image_classifier.ImageClassifier ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions ImageClassifierOptions = image_classifier.ImageClassifierOptions
ImageSegmenter = image_segmenter.ImageSegmenter
ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
ObjectDetector = object_detector.ObjectDetector ObjectDetector = object_detector.ObjectDetector
ObjectDetectorOptions = object_detector.ObjectDetectorOptions ObjectDetectorOptions = object_detector.ObjectDetectorOptions
RunningMode = core.vision_task_running_mode.VisionTaskRunningMode RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
# Remove unnecessary modules to avoid duplication in API docs. # Remove unnecessary modules to avoid duplication in API docs.
del core del core
del gesture_recognizer
del image_classifier del image_classifier
del image_segmenter
del object_detector del object_detector
del mediapipe del mediapipe

View File

@ -44,18 +44,6 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
@dataclasses.dataclass @dataclasses.dataclass
class ImageSegmenterOptions: class ImageSegmenterOptions:
"""Options for the image segmenter task. """Options for the image segmenter task.
@ -74,6 +62,17 @@ class ImageSegmenterOptions:
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_type: Optional[OutputType] = OutputType.CATEGORY_MASK

View File

@ -0,0 +1,21 @@
# This package contains options shared by all MediaPipe Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "category",
srcs = ["category.d.ts"],
)
mediapipe_ts_library(
name = "classifications",
srcs = ["classifications.d.ts"],
deps = [":category"],
)
mediapipe_ts_library(
name = "landmark",
srcs = ["landmark.d.ts"],
)

View File

@ -0,0 +1,38 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** A classification category. */
export interface Category {
/** The probability score of this label category. */
score: number;
/** The index of the category in the corresponding label file. */
index: number;
/**
* The label of this category object. Defaults to an empty string if there is
* no category.
*/
categoryName: string;
/**
* The display name of the label, which may be translated for different
* locales. For example, a label, "apple", may be translated into Spanish for
* display purpose, so that the `display_name` is "manzana". Defaults to an
* empty string if there is no display name.
*/
displayName: string;
}

View File

@ -0,0 +1,51 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Category} from '../../../../tasks/web/components/containers/category';
/** List of predicted categories with an optional timestamp. */
export interface ClassificationEntry {
/**
* The array of predicted categories, usually sorted by descending scores,
* e.g., from high to low probability.
*/
categories: Category[];
/**
* The optional timestamp (in milliseconds) associated to the classification
* entry. This is useful for time series use cases, e.g., audio
* classification.
*/
timestampMs?: number;
}
/** Classifications for a given classifier head. */
export interface Classifications {
/** A list of classification entries. */
entries: ClassificationEntry[];
/**
* The index of the classifier head these categories refer to. This is
* useful for multi-head models.
*/
headIndex: number;
/**
* The name of the classifier head, which is the corresponding tensor
* metadata name.
*/
headName: string;
}

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Landmark represents a point in 3D space with x, y, z coordinates. If
* normalized is true, the landmark coordinates is normalized respect to the
* dimension of image, and the coordinates values are in the range of [0,1].
* Otherwise, it represenet a point in world coordinates.
*/
export class Landmark {
/** The x coordinates of the landmark. */
x: number;
/** The y coordinates of the landmark. */
y: number;
/** The z coordinates of the landmark. */
z: number;
/** Whether this landmark is normalized with respect to the image size. */
normalized: boolean;
}

View File

@ -0,0 +1,33 @@
# This package contains options shared by all MediaPipe Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "classifier_options",
srcs = ["classifier_options.ts"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto",
"//mediapipe/tasks/web/core:classifier_options",
],
)
mediapipe_ts_library(
name = "classifier_result",
srcs = ["classifier_result.ts"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classifications",
],
)
mediapipe_ts_library(
name = "base_options",
srcs = ["base_options.ts"],
deps = [
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core",
],
)

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided
*/
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
Promise<BaseOptionsProto> {
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
let modelAssetBuffer = baseOptions.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(baseOptions.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
const proto = new BaseOptionsProto();
const externalFile = new ExternalFile();
externalFile.setFileContent(modelAssetBuffer);
proto.setModelAsset(externalFile);
return proto;
}

View File

@ -0,0 +1,62 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb';
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
/**
* Converts a ClassifierOptions object to its Proto representation, optionally
* based on existing definition.
* @param options The options object to convert to a Proto. Only options that
* are expliclty provided are set.
* @param baseOptions A base object that options can be merged into.
*/
export function convertClassifierOptionsToProto(
options: ClassifierOptions,
baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
const classifierOptions =
baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
if (options.displayNamesLocale) {
classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
} else if (options.displayNamesLocale === undefined) {
classifierOptions.clearDisplayNamesLocale();
}
if (options.maxResults) {
classifierOptions.setMaxResults(options.maxResults);
} else if ('maxResults' in options) { // Check for undefined
classifierOptions.clearMaxResults();
}
if (options.scoreThreshold) {
classifierOptions.setScoreThreshold(options.scoreThreshold);
} else if ('scoreThreshold' in options) { // Check for undefined
classifierOptions.clearScoreThreshold();
}
if (options.categoryAllowlist) {
classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
} else if ('categoryAllowlist' in options) { // Check for undefined
classifierOptions.clearCategoryAllowlistList();
}
if (options.categoryDenylist) {
classifierOptions.setCategoryDenylistList(options.categoryDenylist);
} else if ('categoryDenylist' in options) { // Check for undefined
classifierOptions.clearCategoryDenylistList();
}
return classifierOptions;
}

View File

@ -0,0 +1,61 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0;
/**
* Converts a ClassificationEntry proto to the ClassificationEntry result
* type.
*/
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
ClassificationEntry {
const categories = source.getCategoriesList().map(category => {
return {
index: category.getIndex() ?? DEFAULT_INDEX,
score: category.getScore() ?? DEFAULT_SCORE,
displayName: category.getDisplayName() ?? '',
categoryName: category.getCategoryName() ?? '',
};
});
return {
categories,
timestampMs: source.getTimestampMs(),
};
}
/**
* Converts a ClassificationResult proto to a list of classifications.
*/
export function convertFromClassificationResultProto(
classificationResult: ClassificationResult) : Classifications[] {
const result: Classifications[] = [];
for (const classificationsProto of
classificationResult.getClassificationsList()) {
const classifications: Classifications = {
entries: classificationsProto.getEntriesList().map(
entry => convertFromClassificationEntryProto(entry)),
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
headName: classificationsProto.getHeadName() ?? '',
};
result.push(classifications);
}
return result;
}

View File

@ -0,0 +1,21 @@
# This package contains options shared by all MediaPipe Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "core",
srcs = [
"base_options.d.ts",
"wasm_loader_options.d.ts",
],
)
mediapipe_ts_library(
name = "classifier_options",
srcs = [
"classifier_options.d.ts",
],
deps = [":core"],
)

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Placeholder for internal dependency on trusted resource url
/** Options to configure MediaPipe Tasks in general. */
export interface BaseOptions {
/**
* The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetPath?: string;
/**
* A buffer containing the model aaset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetBuffer?: Uint8Array;
}

View File

@ -0,0 +1,52 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions} from '../../../tasks/web/core/base_options';
/** Options to configure the Mediapipe Classifier Task. */
export interface ClassifierOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
/**
* The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English.
*/
displayNamesLocale?: string|undefined;
/** The maximum number of top-scored detection results to return. */
maxResults?: number|undefined;
/**
* Overrides the value provided in the model metadata. Results below this
* value are rejected.
*/
scoreThreshold?: number|undefined;
/**
* Allowlist of category names. If non-empty, detection results whose category
* name is not in this set will be filtered out. Duplicate or unknown category
* names are ignored. Mutually exclusive with `categoryDenylist`.
*/
categoryAllowlist?: string[]|undefined;
/**
* Denylist of category names. If non-empty, detection results whose category
* name is in this set will be filtered out. Duplicate or unknown category
* names are ignored. Mutually exclusive with `categoryAllowlist`.
*/
categoryDenylist?: string[]|undefined;
}

View File

@ -0,0 +1,25 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Placeholder for internal dependency on trusted resource url
/** An object containing the locations of all Wasm assets */
export interface WasmLoaderOptions {
/** The path to the Wasm loader script. */
wasmLoaderPath: string;
/** The path to the Wasm binary. */
wasmBinaryPath: string;
}

15
package.json Normal file
View File

@ -0,0 +1,15 @@
{
"name": "medipipe-dev",
"version": "0.0.0-alphga",
"description": "MediaPipe GitHub repo",
"devDependencies": {
"@bazel/typescript": "^5.7.1",
"@types/google-protobuf": "^3.15.6",
"@types/offscreencanvas": "^2019.7.0",
"google-protobuf": "^3.21.2",
"protobufjs": "^7.1.2",
"protobufjs-cli": "^1.0.2",
"ts-protoc-gen": "^0.15.0",
"typescript": "^4.8.4"
}
}

View File

@ -1,5 +1,6 @@
absl-py absl-py
attrs>=19.1.0 attrs>=19.1.0
flatbuffers>=2.0
matplotlib matplotlib
numpy numpy
opencv-contrib-python opencv-contrib-python

View File

@ -129,6 +129,15 @@ def _add_mp_init_files():
mp_dir_init_file.close() mp_dir_init_file.close()
def _copy_to_build_lib_dir(build_lib, file):
"""Copy a file from bazel-bin to the build lib dir."""
dst = os.path.join(build_lib + '/', file)
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
shutil.copyfile(os.path.join('bazel-bin/', file), dst)
class GeneratePyProtos(build_ext.build_ext): class GeneratePyProtos(build_ext.build_ext):
"""Generate MediaPipe Python protobuf files by Protocol Compiler.""" """Generate MediaPipe Python protobuf files by Protocol Compiler."""
@ -259,7 +268,7 @@ class BuildModules(build_ext.build_ext):
] ]
if subprocess.call(fetch_model_command) != 0: if subprocess.call(fetch_model_command) != 0:
sys.exit(-1) sys.exit(-1)
self._copy_to_build_lib_dir(external_file) _copy_to_build_lib_dir(self.build_lib, external_file)
def _generate_binary_graph(self, binary_graph_target): def _generate_binary_graph(self, binary_graph_target):
"""Generate binary graph for a particular MediaPipe binary graph target.""" """Generate binary graph for a particular MediaPipe binary graph target."""
@ -277,15 +286,27 @@ class BuildModules(build_ext.build_ext):
bazel_command.append('--define=OPENCV=source') bazel_command.append('--define=OPENCV=source')
if subprocess.call(bazel_command) != 0: if subprocess.call(bazel_command) != 0:
sys.exit(-1) sys.exit(-1)
self._copy_to_build_lib_dir(binary_graph_target + '.binarypb') _copy_to_build_lib_dir(self.build_lib, binary_graph_target + '.binarypb')
def _copy_to_build_lib_dir(self, file):
"""Copy a file from bazel-bin to the build lib dir.""" class GenerateMetadataSchema(build_ext.build_ext):
dst = os.path.join(self.build_lib + '/', file) """Generate metadata python schema files."""
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir): def run(self):
os.makedirs(dst_dir) for target in ['metadata_schema_py', 'schema_py']:
shutil.copyfile(os.path.join('bazel-bin/', file), dst) bazel_command = [
'bazel',
'build',
'--compilation_mode=opt',
'--define=MEDIAPIPE_DISABLE_GPU=1',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
'//mediapipe/tasks/metadata:' + target,
]
if subprocess.call(bazel_command) != 0:
sys.exit(-1)
_copy_to_build_lib_dir(
self.build_lib,
'mediapipe/tasks/metadata/' + target + '_generated.py')
class BazelExtension(setuptools.Extension): class BazelExtension(setuptools.Extension):
@ -375,6 +396,7 @@ class BuildPy(build_py.build_py):
build_ext_obj = self.distribution.get_command_obj('build_ext') build_ext_obj = self.distribution.get_command_obj('build_ext')
build_ext_obj.link_opencv = self.link_opencv build_ext_obj.link_opencv = self.link_opencv
self.run_command('gen_protos') self.run_command('gen_protos')
self.run_command('generate_metadata_schema')
self.run_command('build_modules') self.run_command('build_modules')
self.run_command('build_ext') self.run_command('build_ext')
build_py.build_py.run(self) build_py.build_py.run(self)
@ -434,18 +456,25 @@ setuptools.setup(
author_email='mediapipe@google.com', author_email='mediapipe@google.com',
long_description=_get_long_description(), long_description=_get_long_description(),
long_description_content_type='text/markdown', long_description_content_type='text/markdown',
packages=setuptools.find_packages(exclude=['mediapipe.examples.desktop.*']), packages=setuptools.find_packages(
exclude=['mediapipe.examples.desktop.*', 'mediapipe.model_maker.*']),
install_requires=_parse_requirements('requirements.txt'), install_requires=_parse_requirements('requirements.txt'),
cmdclass={ cmdclass={
'build_py': BuildPy, 'build_py': BuildPy,
'gen_protos': GeneratePyProtos,
'build_modules': BuildModules, 'build_modules': BuildModules,
'build_ext': BuildExtension, 'build_ext': BuildExtension,
'generate_metadata_schema': GenerateMetadataSchema,
'gen_protos': GeneratePyProtos,
'install': Install, 'install': Install,
'restore': Restore, 'restore': Restore,
}, },
ext_modules=[ ext_modules=[
BazelExtension('//mediapipe/python:_framework_bindings'), BazelExtension('//mediapipe/python:_framework_bindings'),
BazelExtension(
'//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version'),
BazelExtension(
'//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers'
),
], ],
zip_safe=False, zip_safe=False,
include_package_data=True, include_package_data=True,

47
tsconfig.json Normal file
View File

@ -0,0 +1,47 @@
{
"compilerOptions": {
"target": "es2017",
"module": "commonjs",
"lib": ["ES2017", "dom"],
"declaration": true,
"moduleResolution": "node",
"esModuleInterop": true,
"noImplicitAny": true,
"inlineSourceMap": true,
"inlineSources": true,
"strict": true,
"types": ["@types/offscreencanvas"],
"rootDirs": [
".",
"./bazel-out/host/bin",
"./bazel-out/darwin-dbg/bin",
"./bazel-out/darwin-fastbuild/bin",
"./bazel-out/darwin-opt/bin",
"./bazel-out/darwin_arm64-dbg/bin",
"./bazel-out/darwin_arm64-fastbuild/bin",
"./bazel-out/darwin_arm64-opt/bin",
"./bazel-out/k8-dbg/bin",
"./bazel-out/k8-fastbuild/bin",
"./bazel-out/k8-opt/bin",
"./bazel-out/x64_windows-dbg/bin",
"./bazel-out/x64_windows-fastbuild/bin",
"./bazel-out/x64_windows-opt/bin",
"./bazel-out/darwin-dbg/bin",
"./bazel-out/darwin-fastbuild/bin",
"./bazel-out/darwin-opt/bin",
"./bazel-out/k8-dbg/bin",
"./bazel-out/k8-fastbuild/bin",
"./bazel-out/k8-opt/bin",
"./bazel-out/x64_windows-dbg/bin",
"./bazel-out/x64_windows-fastbuild/bin",
"./bazel-out/x64_windows-opt/bin",
"./bazel-out/k8-fastbuild-ST-4a519fd6d3e4/bin"
]
},
"exclude": [
"./_bazel_bin",
"./_bazel_buildbot",
"./_bazel_out",
"./_bazel_testlogs"
]
}

626
yarn.lock Normal file
View File

@ -0,0 +1,626 @@
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
# yarn lockfile v1
"@babel/parser@^7.9.4":
version "7.20.1"
resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.1.tgz#3e045a92f7b4623cafc2425eddcb8cf2e54f9cc5"
integrity sha512-hp0AYxaZJhxULfM1zyp7Wgr+pSUKBcP3M+PHnSzWGdXOzg/kHWIgiUWARvubhUKGOEw3xqY4x+lyZ9ytBVcELw==
"@bazel/typescript@^5.7.1":
version "5.7.1"
resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682"
integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g==
dependencies:
"@bazel/worker" "5.7.1"
semver "5.6.0"
source-map-support "0.5.9"
tsutils "3.21.0"
"@bazel/worker@5.7.1":
version "5.7.1"
resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad"
integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg==
dependencies:
google-protobuf "^3.6.1"
"@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2":
version "1.1.2"
resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf"
integrity sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==
"@protobufjs/base64@^1.1.2":
version "1.1.2"
resolved "https://registry.yarnpkg.com/@protobufjs/base64/-/base64-1.1.2.tgz#4c85730e59b9a1f1f349047dbf24296034bb2735"
integrity sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==
"@protobufjs/codegen@^2.0.4":
version "2.0.4"
resolved "https://registry.yarnpkg.com/@protobufjs/codegen/-/codegen-2.0.4.tgz#7ef37f0d010fb028ad1ad59722e506d9262815cb"
integrity sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==
"@protobufjs/eventemitter@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz#355cbc98bafad5978f9ed095f397621f1d066b70"
integrity sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==
"@protobufjs/fetch@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/fetch/-/fetch-1.1.0.tgz#ba99fb598614af65700c1619ff06d454b0d84c45"
integrity sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==
dependencies:
"@protobufjs/aspromise" "^1.1.1"
"@protobufjs/inquire" "^1.1.0"
"@protobufjs/float@^1.0.2":
version "1.0.2"
resolved "https://registry.yarnpkg.com/@protobufjs/float/-/float-1.0.2.tgz#5e9e1abdcb73fc0a7cb8b291df78c8cbd97b87d1"
integrity sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==
"@protobufjs/inquire@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/inquire/-/inquire-1.1.0.tgz#ff200e3e7cf2429e2dcafc1140828e8cc638f089"
integrity sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==
"@protobufjs/path@^1.1.2":
version "1.1.2"
resolved "https://registry.yarnpkg.com/@protobufjs/path/-/path-1.1.2.tgz#6cc2b20c5c9ad6ad0dccfd21ca7673d8d7fbf68d"
integrity sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==
"@protobufjs/pool@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/pool/-/pool-1.1.0.tgz#09fd15f2d6d3abfa9b65bc366506d6ad7846ff54"
integrity sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==
"@protobufjs/utf8@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"
integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==
"@types/google-protobuf@^3.15.6":
version "3.15.6"
resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504"
integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw==
"@types/linkify-it@*":
version "3.0.2"
resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9"
integrity sha512-HZQYqbiFVWufzCwexrvh694SOim8z2d+xJl5UNamcvQFejLY/2YUtzXHYi3cHdI7PMlS8ejH2slRAOJQ32aNbA==
"@types/markdown-it@^12.2.3":
version "12.2.3"
resolved "https://registry.yarnpkg.com/@types/markdown-it/-/markdown-it-12.2.3.tgz#0d6f6e5e413f8daaa26522904597be3d6cd93b51"
integrity sha512-GKMHFfv3458yYy+v/N8gjufHO6MSZKCOXpZc5GXIWWy8uldwfmPn98vp81gZ5f9SVw8YYBctgfJ22a2d7AOMeQ==
dependencies:
"@types/linkify-it" "*"
"@types/mdurl" "*"
"@types/mdurl@*":
version "1.0.2"
resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9"
integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA==
"@types/node@>=13.7.0":
version "18.11.9"
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4"
integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg==
"@types/offscreencanvas@^2019.7.0":
version "2019.7.0"
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d"
integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg==
acorn-jsx@^5.3.2:
version "5.3.2"
resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937"
integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==
acorn@^8.8.0:
version "8.8.1"
resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73"
integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==
ansi-styles@^4.1.0:
version "4.3.0"
resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937"
integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==
dependencies:
color-convert "^2.0.1"
argparse@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38"
integrity sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==
balanced-match@^1.0.0:
version "1.0.2"
resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee"
integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==
bluebird@^3.7.2:
version "3.7.2"
resolved "https://registry.yarnpkg.com/bluebird/-/bluebird-3.7.2.tgz#9f229c15be272454ffa973ace0dbee79a1b0c36f"
integrity sha512-XpNj6GDQzdfW+r2Wnn7xiSAd7TM3jzkxGXBGTtWKuSXv1xUV+azxAm8jdWZN06QTQk+2N2XB9jRDkvbmQmcRtg==
brace-expansion@^1.1.7:
version "1.1.11"
resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd"
integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==
dependencies:
balanced-match "^1.0.0"
concat-map "0.0.1"
brace-expansion@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-2.0.1.tgz#1edc459e0f0c548486ecf9fc99f2221364b9a0ae"
integrity sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==
dependencies:
balanced-match "^1.0.0"
buffer-from@^1.0.0:
version "1.1.2"
resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.2.tgz#2b146a6fd72e80b4f55d255f35ed59a3a9a41bd5"
integrity sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==
catharsis@^0.9.0:
version "0.9.0"
resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121"
integrity sha512-prMTQVpcns/tzFgFVkVp6ak6RykZyWb3gu8ckUpd6YkTlacOd3DXGJjIpD4Q6zJirizvaiAjSSHlOsA+6sNh2A==
dependencies:
lodash "^4.17.15"
chalk@^4.0.0:
version "4.1.2"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01"
integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==
dependencies:
ansi-styles "^4.1.0"
supports-color "^7.1.0"
color-convert@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3"
integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==
dependencies:
color-name "~1.1.4"
color-name@~1.1.4:
version "1.1.4"
resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2"
integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==
concat-map@0.0.1:
version "0.0.1"
resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b"
integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==
deep-is@~0.1.3:
version "0.1.4"
resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831"
integrity sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==
entities@~2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5"
integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w==
escape-string-regexp@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344"
integrity sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w==
escodegen@^1.13.0:
version "1.14.3"
resolved "https://registry.yarnpkg.com/escodegen/-/escodegen-1.14.3.tgz#4e7b81fba61581dc97582ed78cab7f0e8d63f503"
integrity sha512-qFcX0XJkdg+PB3xjZZG/wKSuT1PnQWx57+TVSjIMmILd2yC/6ByYElPwJnslDsuWuSAp4AwJGumarAAmJch5Kw==
dependencies:
esprima "^4.0.1"
estraverse "^4.2.0"
esutils "^2.0.2"
optionator "^0.8.1"
optionalDependencies:
source-map "~0.6.1"
eslint-visitor-keys@^3.3.0:
version "3.3.0"
resolved "https://registry.yarnpkg.com/eslint-visitor-keys/-/eslint-visitor-keys-3.3.0.tgz#f6480fa6b1f30efe2d1968aa8ac745b862469826"
integrity sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA==
espree@^9.0.0:
version "9.4.0"
resolved "https://registry.yarnpkg.com/espree/-/espree-9.4.0.tgz#cd4bc3d6e9336c433265fc0aa016fc1aaf182f8a"
integrity sha512-DQmnRpLj7f6TgN/NYb0MTzJXL+vJF9h3pHy4JhCIs3zwcgez8xmGg3sXHcEO97BrmO2OSvCwMdfdlyl+E9KjOw==
dependencies:
acorn "^8.8.0"
acorn-jsx "^5.3.2"
eslint-visitor-keys "^3.3.0"
esprima@^4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/esprima/-/esprima-4.0.1.tgz#13b04cdb3e6c5d19df91ab6987a8695619b0aa71"
integrity sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==
estraverse@^4.2.0:
version "4.3.0"
resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-4.3.0.tgz#398ad3f3c5a24948be7725e83d11a7de28cdbd1d"
integrity sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==
estraverse@^5.1.0:
version "5.3.0"
resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.3.0.tgz#2eea5290702f26ab8fe5370370ff86c965d21123"
integrity sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==
esutils@^2.0.2:
version "2.0.3"
resolved "https://registry.yarnpkg.com/esutils/-/esutils-2.0.3.tgz#74d2eb4de0b8da1293711910d50775b9b710ef64"
integrity sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==
fast-levenshtein@~2.0.6:
version "2.0.6"
resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917"
integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==
fs.realpath@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f"
integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==
glob@^7.1.3:
version "7.2.3"
resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b"
integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==
dependencies:
fs.realpath "^1.0.0"
inflight "^1.0.4"
inherits "2"
minimatch "^3.1.1"
once "^1.3.0"
path-is-absolute "^1.0.0"
glob@^8.0.0:
version "8.0.3"
resolved "https://registry.yarnpkg.com/glob/-/glob-8.0.3.tgz#415c6eb2deed9e502c68fa44a272e6da6eeca42e"
integrity sha512-ull455NHSHI/Y1FqGaaYFaLGkNMMJbavMrEGFXG/PGrg6y7sutWHUHrz6gy6WEBH6akM1M414dWKCNs+IhKdiQ==
dependencies:
fs.realpath "^1.0.0"
inflight "^1.0.4"
inherits "2"
minimatch "^5.0.1"
once "^1.3.0"
google-protobuf@^3.15.5, google-protobuf@^3.21.2, google-protobuf@^3.6.1:
version "3.21.2"
resolved "https://registry.yarnpkg.com/google-protobuf/-/google-protobuf-3.21.2.tgz#4580a2bea8bbb291ee579d1fefb14d6fa3070ea4"
integrity sha512-3MSOYFO5U9mPGikIYCzK0SaThypfGgS6bHqrUGXG3DPHCrb+txNqeEcns1W0lkGfk0rCyNXm7xB9rMxnCiZOoA==
graceful-fs@^4.1.9:
version "4.2.10"
resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.10.tgz#147d3a006da4ca3ce14728c7aefc287c367d7a6c"
integrity sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==
has-flag@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b"
integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==
inflight@^1.0.4:
version "1.0.6"
resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9"
integrity sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==
dependencies:
once "^1.3.0"
wrappy "1"
inherits@2:
version "2.0.4"
resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c"
integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==
js2xmlparser@^4.0.2:
version "4.0.2"
resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a"
integrity sha512-6n4D8gLlLf1n5mNLQPRfViYzu9RATblzPEtm1SthMX1Pjao0r9YI9nw7ZIfRxQMERS87mcswrg+r/OYrPRX6jA==
dependencies:
xmlcreate "^2.0.4"
jsdoc@^3.6.3:
version "3.6.11"
resolved "https://registry.yarnpkg.com/jsdoc/-/jsdoc-3.6.11.tgz#8bbb5747e6f579f141a5238cbad4e95e004458ce"
integrity sha512-8UCU0TYeIYD9KeLzEcAu2q8N/mx9O3phAGl32nmHlE0LpaJL71mMkP4d+QE5zWfNt50qheHtOZ0qoxVrsX5TUg==
dependencies:
"@babel/parser" "^7.9.4"
"@types/markdown-it" "^12.2.3"
bluebird "^3.7.2"
catharsis "^0.9.0"
escape-string-regexp "^2.0.0"
js2xmlparser "^4.0.2"
klaw "^3.0.0"
markdown-it "^12.3.2"
markdown-it-anchor "^8.4.1"
marked "^4.0.10"
mkdirp "^1.0.4"
requizzle "^0.2.3"
strip-json-comments "^3.1.0"
taffydb "2.6.2"
underscore "~1.13.2"
klaw@^3.0.0:
version "3.0.0"
resolved "https://registry.yarnpkg.com/klaw/-/klaw-3.0.0.tgz#b11bec9cf2492f06756d6e809ab73a2910259146"
integrity sha512-0Fo5oir+O9jnXu5EefYbVK+mHMBeEVEy2cmctR1O1NECcCkPRreJKrS6Qt/j3KC2C148Dfo9i3pCmCMsdqGr0g==
dependencies:
graceful-fs "^4.1.9"
levn@~0.3.0:
version "0.3.0"
resolved "https://registry.yarnpkg.com/levn/-/levn-0.3.0.tgz#3b09924edf9f083c0490fdd4c0bc4421e04764ee"
integrity sha512-0OO4y2iOHix2W6ujICbKIaEQXvFQHue65vUG3pb5EUomzPI90z9hsA1VsO/dbIIpC53J8gxM9Q4Oho0jrCM/yA==
dependencies:
prelude-ls "~1.1.2"
type-check "~0.3.2"
linkify-it@^3.0.1:
version "3.0.3"
resolved "https://registry.yarnpkg.com/linkify-it/-/linkify-it-3.0.3.tgz#a98baf44ce45a550efb4d49c769d07524cc2fa2e"
integrity sha512-ynTsyrFSdE5oZ/O9GEf00kPngmOfVwazR5GKDq6EYfhlpFug3J2zybX56a2PRRpc9P+FuSoGNAwjlbDs9jJBPQ==
dependencies:
uc.micro "^1.0.1"
lodash@^4.17.14, lodash@^4.17.15:
version "4.17.21"
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
long@^5.0.0:
version "5.2.0"
resolved "https://registry.yarnpkg.com/long/-/long-5.2.0.tgz#2696dadf4b4da2ce3f6f6b89186085d94d52fd61"
integrity sha512-9RTUNjK60eJbx3uz+TEGF7fUr29ZDxR5QzXcyDpeSfeH28S9ycINflOgOlppit5U+4kNTe83KQnMEerw7GmE8w==
lru-cache@^6.0.0:
version "6.0.0"
resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-6.0.0.tgz#6d6fe6570ebd96aaf90fcad1dafa3b2566db3a94"
integrity sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==
dependencies:
yallist "^4.0.0"
markdown-it-anchor@^8.4.1:
version "8.6.5"
resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8"
integrity sha512-PI1qEHHkTNWT+X6Ip9w+paonfIQ+QZP9sCeMYi47oqhH+EsW8CrJ8J7CzV19QVOj6il8ATGbK2nTECj22ZHGvQ==
markdown-it@^12.3.2:
version "12.3.2"
resolved "https://registry.yarnpkg.com/markdown-it/-/markdown-it-12.3.2.tgz#bf92ac92283fe983fe4de8ff8abfb5ad72cd0c90"
integrity sha512-TchMembfxfNVpHkbtriWltGWc+m3xszaRD0CZup7GFFhzIgQqxIfn3eGj1yZpfuflzPvfkt611B2Q/Bsk1YnGg==
dependencies:
argparse "^2.0.1"
entities "~2.1.0"
linkify-it "^3.0.1"
mdurl "^1.0.1"
uc.micro "^1.0.5"
marked@^4.0.10:
version "4.2.1"
resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.1.tgz#eaa32594e45b4e58c02e4d118531fd04345de3b4"
integrity sha512-VK1/jNtwqDLvPktNpL0Fdg3qoeUZhmRsuiIjPEy/lHwXW4ouLoZfO4XoWd4ClDt+hupV1VLpkZhEovjU0W/kqA==
mdurl@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e"
integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g==
minimatch@^3.1.1:
version "3.1.2"
resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b"
integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==
dependencies:
brace-expansion "^1.1.7"
minimatch@^5.0.1:
version "5.1.0"
resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7"
integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg==
dependencies:
brace-expansion "^2.0.1"
minimist@^1.2.0:
version "1.2.7"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.7.tgz#daa1c4d91f507390437c6a8bc01078e7000c4d18"
integrity sha512-bzfL1YUZsP41gmu/qjrEk0Q6i2ix/cVeAhbCbqH9u3zYutS1cLg00qhrD0M2MVdCcx4Sc0UpP2eBWo9rotpq6g==
mkdirp@^1.0.4:
version "1.0.4"
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-1.0.4.tgz#3eb5ed62622756d79a5f0e2a221dfebad75c2f7e"
integrity sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==
once@^1.3.0:
version "1.4.0"
resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1"
integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==
dependencies:
wrappy "1"
optionator@^0.8.1:
version "0.8.3"
resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.3.tgz#84fa1d036fe9d3c7e21d99884b601167ec8fb495"
integrity sha512-+IW9pACdk3XWmmTXG8m3upGUJst5XRGzxMRjXzAuJ1XnIFNvfhjjIuYkDvysnPQ7qzqVzLt78BCruntqRhWQbA==
dependencies:
deep-is "~0.1.3"
fast-levenshtein "~2.0.6"
levn "~0.3.0"
prelude-ls "~1.1.2"
type-check "~0.3.2"
word-wrap "~1.2.3"
path-is-absolute@^1.0.0:
version "1.0.1"
resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f"
integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==
prelude-ls@~1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.1.2.tgz#21932a549f5e52ffd9a827f570e04be62a97da54"
integrity sha512-ESF23V4SKG6lVSGZgYNpbsiaAkdab6ZgOxe52p7+Kid3W3u3bxR4Vfd/o21dmN7jSt0IwgZ4v5MUd26FEtXE9w==
protobufjs-cli@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/protobufjs-cli/-/protobufjs-cli-1.0.2.tgz#905fc49007cf4aaf3c45d5f250eb294eedeea062"
integrity sha512-cz9Pq9p/Zs7okc6avH20W7QuyjTclwJPgqXG11jNaulfS3nbVisID8rC+prfgq0gbZE0w9LBFd1OKFF03kgFzg==
dependencies:
chalk "^4.0.0"
escodegen "^1.13.0"
espree "^9.0.0"
estraverse "^5.1.0"
glob "^8.0.0"
jsdoc "^3.6.3"
minimist "^1.2.0"
semver "^7.1.2"
tmp "^0.2.1"
uglify-js "^3.7.7"
protobufjs@^7.1.2:
version "7.1.2"
resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-7.1.2.tgz#a0cf6aeaf82f5625bffcf5a38b7cd2a7de05890c"
integrity sha512-4ZPTPkXCdel3+L81yw3dG6+Kq3umdWKh7Dc7GW/CpNk4SX3hK58iPCWeCyhVTDrbkNeKrYNZ7EojM5WDaEWTLQ==
dependencies:
"@protobufjs/aspromise" "^1.1.2"
"@protobufjs/base64" "^1.1.2"
"@protobufjs/codegen" "^2.0.4"
"@protobufjs/eventemitter" "^1.1.0"
"@protobufjs/fetch" "^1.1.0"
"@protobufjs/float" "^1.0.2"
"@protobufjs/inquire" "^1.1.0"
"@protobufjs/path" "^1.1.2"
"@protobufjs/pool" "^1.1.0"
"@protobufjs/utf8" "^1.1.0"
"@types/node" ">=13.7.0"
long "^5.0.0"
requizzle@^0.2.3:
version "0.2.3"
resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded"
integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ==
dependencies:
lodash "^4.17.14"
rimraf@^3.0.0:
version "3.0.2"
resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-3.0.2.tgz#f1a5402ba6220ad52cc1282bac1ae3aa49fd061a"
integrity sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==
dependencies:
glob "^7.1.3"
semver@5.6.0:
version "5.6.0"
resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004"
integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg==
semver@^7.1.2:
version "7.3.8"
resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798"
integrity sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==
dependencies:
lru-cache "^6.0.0"
source-map-support@0.5.9:
version "0.5.9"
resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f"
integrity sha512-gR6Rw4MvUlYy83vP0vxoVNzM6t8MUXqNuRsuBmBHQDu1Fh6X015FrLdgoDKcNdkwGubozq0P4N0Q37UyFVr1EA==
dependencies:
buffer-from "^1.0.0"
source-map "^0.6.0"
source-map@^0.6.0, source-map@~0.6.1:
version "0.6.1"
resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263"
integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==
strip-json-comments@^3.1.0:
version "3.1.1"
resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006"
integrity sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==
supports-color@^7.1.0:
version "7.2.0"
resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.2.0.tgz#1b7dcdcb32b8138801b3e478ba6a51caa89648da"
integrity sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==
dependencies:
has-flag "^4.0.0"
taffydb@2.6.2:
version "2.6.2"
resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268"
integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA==
tmp@^0.2.1:
version "0.2.1"
resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14"
integrity sha512-76SUhtfqR2Ijn+xllcI5P1oyannHNHByD80W1q447gU3mp9G9PSpGdWmjUOHRDPiHYacIk66W7ubDTuPF3BEtQ==
dependencies:
rimraf "^3.0.0"
ts-protoc-gen@^0.15.0:
version "0.15.0"
resolved "https://registry.yarnpkg.com/ts-protoc-gen/-/ts-protoc-gen-0.15.0.tgz#2fec5930b46def7dcc9fa73c060d770b7b076b7b"
integrity sha512-TycnzEyrdVDlATJ3bWFTtra3SCiEP0W0vySXReAuEygXCUr1j2uaVyL0DhzjwuUdQoW5oXPwk6oZWeA0955V+g==
dependencies:
google-protobuf "^3.15.5"
tslib@^1.8.1:
version "1.14.1"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00"
integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==
tsutils@3.21.0:
version "3.21.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.21.0.tgz#b48717d394cea6c1e096983eed58e9d61715b623"
integrity sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==
dependencies:
tslib "^1.8.1"
type-check@~0.3.2:
version "0.3.2"
resolved "https://registry.yarnpkg.com/type-check/-/type-check-0.3.2.tgz#5884cab512cf1d355e3fb784f30804b2b520db72"
integrity sha512-ZCmOJdvOWDBYJlzAoFkC+Q0+bUyEOS1ltgp1MGU03fqHG+dbi9tBFU2Rd9QKiDZFAYrhPh2JUf7rZRIuHRKtOg==
dependencies:
prelude-ls "~1.1.2"
typescript@^4.8.4:
version "4.8.4"
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6"
integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ==
uc.micro@^1.0.1, uc.micro@^1.0.5:
version "1.0.6"
resolved "https://registry.yarnpkg.com/uc.micro/-/uc.micro-1.0.6.tgz#9c411a802a409a91fc6cf74081baba34b24499ac"
integrity sha512-8Y75pvTYkLJW2hWQHXxoqRgV7qb9B+9vFEtidML+7koHUFapnVJAZ6cKs+Qjz5Aw3aZWHMC6u0wJE3At+nSGwA==
uglify-js@^3.7.7:
version "3.17.4"
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
underscore@~1.13.2:
version "1.13.6"
resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441"
integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==
word-wrap@~1.2.3:
version "1.2.3"
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
wrappy@1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f"
integrity sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==
xmlcreate@^2.0.4:
version "2.0.4"
resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be"
integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg==
yallist@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==