Merge branch 'google:master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-09 12:41:55 +05:30 committed by GitHub
commit 36c50ff8f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
121 changed files with 13576 additions and 622 deletions

View File

@ -1324,6 +1324,7 @@ cc_test(
name = "image_to_tensor_utils_test", name = "image_to_tensor_utils_test",
srcs = ["image_to_tensor_utils_test.cc"], srcs = ["image_to_tensor_utils_test.cc"],
deps = [ deps = [
":image_to_tensor_calculator_cc_proto",
":image_to_tensor_utils", ":image_to_tensor_utils",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",

View File

@ -330,9 +330,8 @@ class GlProcessor : public ImageToTensorConverter {
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
RET_CHECK_EQ(output_shape.dims.size(), 4) RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size(); << "Wrong output dims size: " << output_shape.dims.size();
RET_CHECK_EQ(output_shape.dims[0], 1) RET_CHECK_GE(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this " << "The batch dimension needs to be greater or equal to 1.";
"converter.";
RET_CHECK_EQ(output_shape.dims[3], 3) RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3]; << "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus(); return absl::OkStatus();

View File

@ -172,7 +172,7 @@ constexpr char kValidIntProto[] = R"(
output_tensor_height: 200 output_tensor_height: 200
)"; )";
TEST(ValidateOptionOutputDims, ValidProtos) { TEST(ValidateOptionOutputDims, ImageToTensorCalcOptions) {
const auto float_options = const auto float_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>( mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidFloatProto); kValidFloatProto);
@ -202,7 +202,7 @@ TEST(ValidateOptionOutputDims, EmptyProto) {
HasSubstr("Valid output tensor width is required"))); HasSubstr("Valid output tensor width is required")));
} }
TEST(GetOutputTensorParams, SetValues) { TEST(GetOutputTensorParams, ImageToTensorCalcOptionsSetValues) {
// Test int range with ImageToTensorCalculatorOptions. // Test int range with ImageToTensorCalculatorOptions.
const auto int_options = const auto int_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>( mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(

Binary file not shown.

After

Width:  |  Height:  |  Size: 319 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@ -106,6 +106,13 @@ class MultiPort : public Single {
return Single{&GetWithAutoGrow(&vec_, index)}; return Single{&GetWithAutoGrow(&vec_, index)};
} }
template <typename U>
auto Cast() {
using SingleCastT =
std::invoke_result_t<decltype(&Single::template Cast<U>), Single*>;
return MultiPort<SingleCastT>(&vec_);
}
private: private:
std::vector<std::unique_ptr<Base>>& vec_; std::vector<std::unique_ptr<Base>>& vec_;
}; };

View File

@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }
TEST(BuilderTest, MultiPortIsCastToMultiPort) {
builder::Graph graph;
builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
builder::MultiSource<int> int_input = any_input.Cast<int>();
builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
builder::MultiDestination<int> int_output = any_output.Cast<int>();
int_input >> int_output;
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "ANY_INPUT:__stream_0"
output_stream: "ANY_OUTPUT:__stream_0"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
builder::Graph graph;
builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
builder::Source<AnyType> any_input = any_multi_input;
builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
builder::Destination<AnyType> any_output = any_multi_output;
any_input >> any_output;
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "ANY_INPUT:__stream_0"
output_stream: "ANY_OUTPUT:__stream_0"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
builder::Graph graph;
builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
int_input >> int_output;
any_input >> any_output;
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "ANY_OUTPUT:__stream_0"
input_stream: "INT_INPUT:__stream_1"
output_stream: "ANY_OUTPUT:__stream_0"
output_stream: "INT_OUTPUT:__stream_1"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace test } // namespace test
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -38,20 +38,18 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
static void EglThreadExitCallback(void* key_value) { static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__) EGLDisplay current_display = eglGetCurrentDisplay();
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE, if (current_display != EGL_NO_DISPLAY) {
EGL_NO_CONTEXT); // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid
#else // display parameter for eglMakeCurrent. This behavior is not portable to
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display // all EGL implementations, and should be considered as an undocumented
// parameter for eglMakeCurrent. This behavior is not portable to all EGL // vendor extension.
// implementations, and should be considered as an undocumented vendor // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
// extension. // Instead, to release the current context, we pass the current display.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml // If the current display is already EGL_NO_DISPLAY, no context is current.
// eglMakeCurrent(current_display, EGL_NO_SURFACE, EGL_NO_SURFACE,
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so). EGL_NO_CONTEXT);
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, }
EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
eglReleaseThread(); eglReleaseThread();
} }

View File

@ -20,3 +20,10 @@ package_group(
"//mediapipe/model_maker/...", "//mediapipe/model_maker/...",
], ],
) )
package_group(
name = "1p_client",
packages = [
"//research/privacy/learning/fl_eval/pcvr/...",
],
)

View File

@ -19,7 +19,6 @@ import tempfile
from typing import Optional from typing import Optional
# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass @dataclasses.dataclass
class BaseHParams: class BaseHParams:
"""Hyperparameters used for training models. """Hyperparameters used for training models.

View File

@ -45,7 +45,10 @@ py_library(
srcs = ["classifier.py"], srcs = ["classifier.py"],
deps = [ deps = [
":custom_model", ":custom_model",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/utils:model_util",
], ],
) )

View File

@ -13,24 +13,24 @@
# limitations under the License. # limitations under the License.
"""Custom classifier.""" """Custom classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os import os
from typing import Any, List from typing import Any, Callable, Optional, Sequence, Union
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.tasks import custom_model from mediapipe.model_maker.python.core.tasks import custom_model
from mediapipe.model_maker.python.core.utils import model_util
class Classifier(custom_model.CustomModel): class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier.""" """An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool): def __init__(self, model_spec: Any, label_names: Sequence[str],
"""Initilizes a classifier with its specifications. shuffle: bool):
"""Initializes a classifier with its specifications.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel):
super(Classifier, self).__init__(model_spec, shuffle) super(Classifier, self).__init__(model_spec, shuffle)
self._label_names = label_names self._label_names = label_names
self._num_classes = len(label_names) self._num_classes = len(label_names)
self._model: tf.keras.Model = None
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
self._loss_function: Union[str, tf.keras.losses.Loss] = None
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None
# TODO: Integrate this into all Model Maker tasks.
def _train_model(self,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None):
"""Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`.
Args:
train_data: Training data.
validation_data: Validation data.
preprocessor: An optional data preprocessor that can be used when
generating a tf.data.Dataset.
"""
tf.compat.v1.logging.info('Training the models...')
if len(train_data) < self._hparams.batch_size:
raise ValueError(
f'The size of the train_data {len(train_data)} can\'t be smaller than'
f' batch_size {self._hparams.batch_size}. To solve this problem, set'
' the batch_size smaller or increase the size of the train_data.')
train_dataset = train_data.gen_tf_dataset(
batch_size=self._hparams.batch_size,
is_training=True,
shuffle=self._shuffle,
preprocess=preprocessor)
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=self._hparams.steps_per_epoch,
batch_size=self._hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=self._hparams.batch_size,
is_training=False,
preprocess=preprocessor)
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=[self._metric_function])
self._history = self._model.fit(
x=train_dataset,
epochs=self._hparams.epochs,
validation_data=validation_dataset,
callbacks=self._callbacks)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset. """Evaluates the classifier with the provided evaluation dataset.

View File

@ -35,6 +35,7 @@ py_library(
name = "model_util", name = "model_util",
srcs = ["model_util.py"], srcs = ["model_util.py"],
deps = [ deps = [
":file_util",
":quantization", ":quantization",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
], ],
@ -50,6 +51,18 @@ py_test(
], ],
) )
py_library(
name = "file_util",
srcs = ["file_util.py"],
)
py_test(
name = "file_util_test",
srcs = ["file_util_test.py"],
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
deps = [":file_util"],
)
py_library( py_library(
name = "loss_functions", name = "loss_functions",
srcs = ["loss_functions.py"], srcs = ["loss_functions.py"],

View File

@ -0,0 +1,36 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for files."""
import os
# resources dependency
def get_absolute_path(file_path: str) -> str:
"""Gets the absolute path of a file.
Args:
file_path: The path to a file relative to the `mediapipe` dir
Returns:
The full path of the file
"""
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
# with the `path` which defines the relative path under mediapipe/, it
# yields to the absolute path of the model files directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, file_path)
return absolute_path

View File

@ -0,0 +1,29 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import absltest
from mediapipe.model_maker.python.core.utils import file_util
class FileUtilTest(absltest.TestCase):
def test_get_absolute_path(self):
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
absolute_path = file_util.get_absolute_path(test_file)
self.assertTrue(os.path.exists(absolute_path))
if __name__ == '__main__':
absltest.main()

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Utilities for keras models.""" """Utilities for models."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -19,21 +19,33 @@ from __future__ import print_function
import os import os
import tempfile import tempfile
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
# Dependency imports # Dependency imports
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
# resources dependency
from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000 ESTIMITED_STEPS_PER_EPOCH = 1000
def get_default_callbacks(
export_dir: str) -> Sequence[tf.keras.callbacks.Callback]:
"""Gets default callbacks."""
summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
return [summary_callback, checkpoint_callback]
def load_keras_model(model_path: str, def load_keras_model(model_path: str,
compile_on_load: bool = False) -> tf.keras.Model: compile_on_load: bool = False) -> tf.keras.Model:
"""Loads a tensorflow Keras model from file and returns the Keras model. """Loads a tensorflow Keras model from file and returns the Keras model.
@ -49,16 +61,26 @@ def load_keras_model(model_path: str,
Returns: Returns:
A tensorflow Keras model. A tensorflow Keras model.
""" """
# Extract the file path before mediapipe/ as the `base_dir`. By joining it absolute_path = file_util.get_absolute_path(model_path)
# with the `model_path` which defines the relative path under mediapipe/, it
# yields to the aboslution path of the model files directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, model_path)
return tf.keras.models.load_model( return tf.keras.models.load_model(
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
def load_tflite_model_buffer(model_path: str) -> bytearray:
"""Loads a TFLite model buffer from file.
Args:
model_path: Relative path to a TFLite file
Returns:
A TFLite model buffer
"""
absolute_path = file_util.get_absolute_path(model_path)
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
tflite_model_buffer = f.read()
return tflite_model_buffer
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None) -> int: train_data: Optional[dataset.Dataset] = None) -> int:
@ -174,7 +196,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
lambda: self.decay_schedule_fn(step), lambda: self.decay_schedule_fn(step),
name=name) name=name)
def get_config(self) -> Dict[Text, Any]: def get_config(self) -> Dict[str, Any]:
return { return {
'initial_learning_rate': self.initial_learning_rate, 'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn, 'decay_schedule_fn': self.decay_schedule_fn,

View File

@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
def test_load_model(self): def test_load_keras_model(self):
input_dim = 4 input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
loaded_model_output = loaded_model.predict_on_batch(input_tensors) loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all()) self.assertTrue((model_output == loaded_model_output).all())
def test_load_tflite_model_buffer(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
test_util.test_tflite(
keras_model=model,
tflite_model=tflite_model_buffer,
size=[1, input_dim])
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
testcase_name='input_only_steps_per_epoch', testcase_name='input_only_steps_per_epoch',

View File

@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["test.txt"],
)

View File

@ -28,6 +28,8 @@ py_library(
":dataset", ":dataset",
":hyperparameters", ":hyperparameters",
":image_classifier", ":image_classifier",
":image_classifier_options",
":model_options",
":model_spec", ":model_spec",
], ],
) )
@ -58,6 +60,24 @@ py_test(
py_library( py_library(
name = "hyperparameters", name = "hyperparameters",
srcs = ["hyperparameters.py"], srcs = ["hyperparameters.py"],
deps = [
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "model_options",
srcs = ["model_options.py"],
)
py_library(
name = "image_classifier_options",
srcs = ["image_classifier_options.py"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
],
) )
py_library( py_library(
@ -74,6 +94,8 @@ py_library(
srcs = ["image_classifier.py"], srcs = ["image_classifier.py"],
deps = [ deps = [
":hyperparameters", ":hyperparameters",
":image_classifier_options",
":model_options",
":model_spec", ":model_spec",
":train_image_classifier_lib", ":train_image_classifier_lib",
"//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/data:classification_dataset",
@ -99,6 +121,7 @@ py_library(
py_test( py_test(
name = "image_classifier_test", name = "image_classifier_test",
size = "large",
srcs = ["image_classifier_test.py"], srcs = ["image_classifier_test.py"],
shard_count = 2, shard_count = 2,
tags = ["requires-net:external"], tags = ["requires-net:external"],

View File

@ -16,10 +16,14 @@
from mediapipe.model_maker.python.vision.image_classifier import dataset from mediapipe.model_maker.python.vision.image_classifier import dataset
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import image_classifier from mediapipe.model_maker.python.vision.image_classifier import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
from mediapipe.model_maker.python.vision.image_classifier import model_options
from mediapipe.model_maker.python.vision.image_classifier import model_spec from mediapipe.model_maker.python.vision.image_classifier import model_spec
ImageClassifier = image_classifier.ImageClassifier ImageClassifier = image_classifier.ImageClassifier
HParams = hyperparameters.HParams HParams = hyperparameters.HParams
Dataset = dataset.Dataset Dataset = dataset.Dataset
ModelOptions = model_options.ImageClassifierModelOptions
ModelSpec = model_spec.ModelSpec ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels SupportedModels = model_spec.SupportedModels
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions

View File

@ -14,28 +14,20 @@
"""Hyperparameters for training image classification models.""" """Hyperparameters for training image classification models."""
import dataclasses import dataclasses
import tempfile
from typing import Optional from mediapipe.model_maker.python.core import hyperparameters as hp
# TODO: Expose other hyperparameters, e.g. data augmentation
# hyperparameters if requested.
@dataclasses.dataclass @dataclasses.dataclass
class HParams: class HParams(hp.BaseHParams):
"""The hyperparameters for training image classifiers. """The hyperparameters for training image classifiers.
The hyperparameters include: Attributes:
# Parameters about training data. learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
do_fine_tuning: If true, the base module is trained together with the do_fine_tuning: If true, the base module is trained together with the
classification layer on top. classification layer on top.
shuffle: A boolean controlling if shuffle the dataset. Default to false.
# Parameters about training configuration
train_epochs: Training will do this many iterations over the dataset.
batch_size: Each training step samples a batch of this many images.
learning_rate: The learning rate to use for gradient descent training.
dropout_rate: The fraction of the input units to drop, used in dropout
layer.
l1_regularizer: A regularizer that applies a L1 regularization penalty. l1_regularizer: A regularizer that applies a L1 regularization penalty.
l2_regularizer: A regularizer that applies a L2 regularization penalty. l2_regularizer: A regularizer that applies a L2 regularization penalty.
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
@ -43,32 +35,21 @@ class HParams:
do_data_augmentation: A boolean controlling whether the training dataset is do_data_augmentation: A boolean controlling whether the training dataset is
augmented by randomly distorting input images, including random cropping, augmented by randomly distorting input images, including random cropping,
flipping, etc. See utils.image_preprocessing documentation for details. flipping, etc. See utils.image_preprocessing documentation for details.
steps_per_epoch: An optional integer indicate the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size devided by batch size.
decay_samples: Number of training samples used to calculate the decay steps decay_samples: Number of training samples used to calculate the decay steps
and create the training optimizer. and create the training optimizer.
warmup_steps: Number of warmup steps for a linear increasing warmup schedule warmup_steps: Number of warmup steps for a linear increasing warmup schedule
on learning rate. Used to set up warmup schedule by model_util.WarmUp. on learning rate. Used to set up warmup schedule by model_util.WarmUp.s
# Parameters about the saved checkpoint
model_dir: The location of model checkpoint files and exported model files.
""" """
# Parameters about training data # Parameters from BaseHParams class.
do_fine_tuning: bool = False learning_rate: float = 0.001
shuffle: bool = False batch_size: int = 2
epochs: int = 10
# Parameters about training configuration # Parameters about training configuration
train_epochs: int = 5 do_fine_tuning: bool = False
batch_size: int = 32
learning_rate: float = 0.005
dropout_rate: float = 0.2
l1_regularizer: float = 0.0 l1_regularizer: float = 0.0
l2_regularizer: float = 0.0001 l2_regularizer: float = 0.0001
label_smoothing: float = 0.1 label_smoothing: float = 0.1
do_data_augmentation: bool = True do_data_augmentation: bool = True
steps_per_epoch: Optional[int] = None # TODO: Use lr_decay in hp.baseHParams to infer decay_samples.
decay_samples: int = 10000 * 256 decay_samples: int = 10000 * 256
warmup_epochs: int = 2 warmup_epochs: int = 2
# Parameters about the saved checkpoint
model_dir: str = tempfile.mkdtemp()

View File

@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.core import image_preprocessing from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model.""" """ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str], def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams): hparams: hp.HParams,
model_options: model_opt.ImageClassifierModelOptions):
"""Initializes ImageClassifier class. """Initializes ImageClassifier class.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
label_names: A list of label names for the classes. label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
model_options: Model options for creating image classifier.
""" """
super().__init__( super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
self._hparams = hparams self._hparams = hparams
self._model_options = model_options
self._preprocess = image_preprocessing.Preprocessor( self._preprocess = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape, input_shape=self._model_spec.input_image_shape,
num_classes=self._num_classes, num_classes=self._num_classes,
@ -57,30 +62,37 @@ class ImageClassifier(classifier.Classifier):
@classmethod @classmethod
def create( def create(
cls, cls,
model_spec: ms.SupportedModels,
train_data: classification_ds.ClassificationDataset, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset,
hparams: Optional[hp.HParams] = None, options: image_classifier_options.ImageClassifierOptions,
) -> 'ImageClassifier': ) -> 'ImageClassifier':
"""Creates and trains an image classifier. """Creates and trains an image classifier.
Loads data and trains the model based on data for image classification. Loads data and trains the model based on data for image classification. If a
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
directory, the training process will load the weight from the checkpoint
file for continual training.
Args: Args:
model_spec: Specification for the model.
train_data: Training data. train_data: Training data.
validation_data: Validation data. validation_data: Validation data.
hparams: Hyperparameters for training image classifier. options: configuration to create image classifier.
Returns: Returns:
An instance based on ImageClassifier. An instance based on ImageClassifier.
""" """
if hparams is None: if options.hparams is None:
hparams = hp.HParams() options.hparams = hp.HParams()
spec = ms.SupportedModels.get(model_spec) if options.model_options is None:
options.model_options = model_opt.ImageClassifierModelOptions()
spec = ms.SupportedModels.get(options.supported_model)
image_classifier = cls( image_classifier = cls(
model_spec=spec, label_names=train_data.label_names, hparams=hparams) model_spec=spec,
label_names=train_data.label_names,
hparams=options.hparams,
model_options=options.model_options)
image_classifier._create_model() image_classifier._create_model()
@ -90,6 +102,7 @@ class ImageClassifier(classifier.Classifier):
return image_classifier return image_classifier
# TODO: Migrate to the shared training library of Model Maker.
def _train(self, train_data: classification_ds.ClassificationDataset, def _train(self, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset): validation_data: classification_ds.ClassificationDataset):
"""Trains the model with input train_data. """Trains the model with input train_data.
@ -142,7 +155,7 @@ class ImageClassifier(classifier.Classifier):
self._model = tf.keras.Sequential([ self._model = tf.keras.Sequential([
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer, tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate), tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
tf.keras.layers.Dense( tf.keras.layers.Dense(
units=self._num_classes, units=self._num_classes,
activation='softmax', activation='softmax',
@ -167,10 +180,10 @@ class ImageClassifier(classifier.Classifier):
path is {self._hparams.model_dir}/{model_name}. path is {self._hparams.model_dir}/{model_name}.
quantization_config: The configuration for model quantization. quantization_config: The configuration for model quantization.
""" """
if not tf.io.gfile.exists(self._hparams.model_dir): if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.model_dir) tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.model_dir, model_name) tflite_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json') metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
tflite_model = model_util.convert_to_tflite( tflite_model = model_util.convert_to_tflite(
model=self._model, model=self._model,
@ -180,7 +193,7 @@ class ImageClassifier(classifier.Classifier):
tflite_model, tflite_model,
self._model_spec.mean_rgb, self._model_spec.mean_rgb,
self._model_spec.stddev_rgb, self._model_spec.stddev_rgb,
labels=metadata_writer.Labels().add(self._label_names)) labels=metadata_writer.Labels().add(list(self._label_names)))
tflite_model_with_metadata, metadata_json = writer.populate() tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file) model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, 'w') as f: with open(metadata_file, 'w') as f:

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building image classifier."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
from mediapipe.model_maker.python.vision.image_classifier import model_spec
@dataclasses.dataclass
class ImageClassifierOptions:
"""Configurable options for building image classifier.
Attributes:
supported_model: A model from the SupportedModels enum.
model_options: A set of options for configuring the selected model.
hparams: A set of hyperparameters used to train the image classifier.
"""
supported_model: model_spec.SupportedModels
model_options: Optional[model_opt.ImageClassifierModelOptions] = None
hparams: Optional[hyperparameters.HParams] = None

View File

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
import filecmp import filecmp
import io
import os import os
import tempfile
from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import mock
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -54,54 +58,74 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
super(ImageClassifierTest, self).setUp() super(ImageClassifierTest, self).setUp()
all_data = self._gen_cmy_data() all_data = self._gen_cmy_data()
# Splits data, 90% data for training, 10% for testing # Splits data, 90% data for training, 10% for testing
self.train_data, self.test_data = all_data.split(0.9) self._train_data, self._test_data = all_data.split(0.9)
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
testcase_name='mobilenet_v2', testcase_name='mobilenet_v2',
model_spec=image_classifier.SupportedModels.MOBILENET_V2, options=image_classifier.ImageClassifierOptions(
hparams=image_classifier.HParams( supported_model=image_classifier.SupportedModels.MOBILENET_V2,
train_epochs=1, batch_size=1, shuffle=True)), hparams=image_classifier.HParams(
epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite0', testcase_name='efficientnet_lite0',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, options=image_classifier.ImageClassifierOptions(
hparams=image_classifier.HParams( supported_model=(
train_epochs=1, batch_size=1, shuffle=True)), image_classifier.SupportedModels.EFFICIENTNET_LITE0),
hparams=image_classifier.HParams(
epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict(
testcase_name='efficientnet_lite0_change_dropout_rate',
options=image_classifier.ImageClassifierOptions(
supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
hparams=image_classifier.HParams(
epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite2', testcase_name='efficientnet_lite2',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2, options=image_classifier.ImageClassifierOptions(
hparams=image_classifier.HParams( supported_model=(
train_epochs=1, batch_size=1, shuffle=True)), image_classifier.SupportedModels.EFFICIENTNET_LITE2),
hparams=image_classifier.HParams(
epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
dict( dict(
testcase_name='efficientnet_lite4', testcase_name='efficientnet_lite4',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4, options=image_classifier.ImageClassifierOptions(
hparams=image_classifier.HParams( supported_model=(
train_epochs=1, batch_size=1, shuffle=True)), image_classifier.SupportedModels.EFFICIENTNET_LITE4),
hparams=image_classifier.HParams(
epochs=1,
batch_size=1,
shuffle=True,
export_dir=tempfile.mkdtemp()))),
) )
def test_create_and_train_model(self, def test_create_and_train_model(
model_spec: image_classifier.SupportedModels, self, options: image_classifier.ImageClassifierOptions):
hparams: image_classifier.HParams):
model = image_classifier.ImageClassifier.create( model = image_classifier.ImageClassifier.create(
model_spec=model_spec, train_data=self._train_data,
train_data=self.train_data, validation_data=self._test_data,
hparams=hparams, options=options)
validation_data=self.test_data)
self._test_accuracy(model)
def test_efficientnetlite0_model_train_and_export(self):
hparams = image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)
model = image_classifier.ImageClassifier.create(
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
train_data=self.train_data,
hparams=hparams,
validation_data=self.test_data)
self._test_accuracy(model) self._test_accuracy(model)
# Test export_model # Test export_model
model.export_model() model.export_model()
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json') output_metadata_file = os.path.join(options.hparams.export_dir,
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite') 'metadata.json')
output_tflite_file = os.path.join(options.hparams.export_dir,
'model.tflite')
expected_metadata_file = test_utils.get_test_data_path('metadata.json') expected_metadata_file = test_utils.get_test_data_path('metadata.json')
self.assertTrue(os.path.exists(output_tflite_file)) self.assertTrue(os.path.exists(output_tflite_file))
@ -111,10 +135,50 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
def test_continual_training_by_loading_checkpoint(self):
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
options = image_classifier.ImageClassifierOptions(
supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
hparams=image_classifier.HParams(
epochs=5, batch_size=1, shuffle=True))
model = image_classifier.ImageClassifier.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
model = image_classifier.ImageClassifier.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
self._test_accuracy(model)
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
def _test_accuracy(self, model, threshold=0.0): def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self.test_data) _, accuracy = model.evaluate(self._test_data)
self.assertGreaterEqual(accuracy, threshold) self.assertGreaterEqual(accuracy, threshold)
@unittest_mock.patch.object(
image_classifier.hyperparameters,
'HParams',
autospec=True,
return_value=image_classifier.HParams(epochs=1))
@unittest_mock.patch.object(
image_classifier.model_options,
'ImageClassifierModelOptions',
autospec=True,
return_value=image_classifier.ModelOptions())
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
self, mock_hparams, mock_model_options):
options = image_classifier.ImageClassifierOptions(
supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
image_classifier.ImageClassifier.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
mock_hparams.assert_called_once()
mock_model_options.assert_called_once()
if __name__ == '__main__': if __name__ == '__main__':
# Load compressed models from tensorflow_hub # Load compressed models from tensorflow_hub

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for image classifier models."""
import dataclasses
@dataclasses.dataclass
class ImageClassifierModelOptions:
"""Configurable options for image classifier model.
Attributes:
dropout_rate: The fraction of the input units to drop, used in dropout
layer.
"""
dropout_rate: float = 0.2

View File

@ -14,8 +14,6 @@
"""Library to train model.""" """Library to train model."""
import os import os
from typing import List
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
@ -49,18 +47,6 @@ def _create_optimizer(init_lr: float, decay_steps: int,
return optimizer return optimizer
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
"""Gets default callbacks."""
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20)
return [summary_callback, checkpoint_callback]
def train_model(model: tf.keras.Model, hparams: hp.HParams, def train_model(model: tf.keras.Model, hparams: hp.HParams,
train_ds: tf.data.Dataset, train_ds: tf.data.Dataset,
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
@ -81,7 +67,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
learning_rate = hparams.learning_rate * hparams.batch_size / 256 learning_rate = hparams.learning_rate * hparams.batch_size / 256
# Get decay steps. # Get decay steps.
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs # NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
total_training_steps = hparams.steps_per_epoch * hparams.epochs
default_decay_steps = hparams.decay_samples // hparams.batch_size default_decay_steps = hparams.decay_samples // hparams.batch_size
decay_steps = max(total_training_steps, default_decay_steps) decay_steps = max(total_training_steps, default_decay_steps)
@ -92,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
loss = tf.keras.losses.CategoricalCrossentropy( loss = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=hparams.label_smoothing) label_smoothing=hparams.label_smoothing)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
callbacks = _get_default_callbacks(hparams.model_dir)
summary_dir = os.path.join(hparams.export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 5 epochs.
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True,
period=5)
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
model.load_weights(latest_checkpoint)
# Train the model. # Train the model.
return model.fit( return model.fit(
x=train_ds, x=train_ds,
epochs=hparams.train_epochs, epochs=hparams.epochs,
validation_data=validation_ds, validation_data=validation_ds,
callbacks=callbacks) callbacks=[summary_callback, checkpoint_callback])

View File

@ -87,7 +87,6 @@ cc_library(
cc_library( cc_library(
name = "builtin_task_graphs", name = "builtin_task_graphs",
deps = [ deps = [
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
@ -95,8 +94,10 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
] + select({ ] + select({
# TODO: Build text_classifier_graph on Windows. # TODO: Build text_classifier_graph on Windows.
# TODO: Build audio_classifier_graph on Windows.
"//mediapipe:windows": [], "//mediapipe:windows": [],
"//conditions:default": [ "//conditions:default": [
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
], ],
}), }),

View File

@ -30,6 +30,15 @@ cc_library(
], ],
) )
cc_library(
name = "hand_landmarks_detection_result",
hdrs = ["hand_landmarks_detection_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library( cc_library(
name = "category", name = "category",
srcs = ["category.cc"], srcs = ["category.cc"],

View File

@ -0,0 +1,43 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace containers {
// The hand landmarks detection result from HandLandmarker, where each vector
// element represents a single hand detected in the image.
struct HandLandmarksDetectionResult {
// Classification of handedness.
std::vector<mediapipe::ClassificationList> handedness;
// Detected hand landmarks in normalized image coordinates.
std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
// Detected hand landmarks in world coordinates.
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
};
} // namespace containers
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.proto;
option java_package = "com.google.mediapipe.tasks.components.proto";
option java_outer_classname = "SegmenterOptionsProto";
// Shared options used by image segmentation tasks. // Shared options used by image segmentation tasks.
message SegmenterOptions { message SegmenterOptions {
// Optional output mask type. // Optional output mask type.

View File

@ -0,0 +1,87 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "text_embedder",
srcs = ["text_embedder.cc"],
hdrs = ["text_embedder.h"],
deps = [
":text_embedder_graph",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "text_embedder_graph",
srcs = ["text_embedder_graph.cc"],
deps = [
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "text_embedder_test",
srcs = ["text_embedder_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
],
deps = [
":text_embedder",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "text_embedder_graph_options_proto",
srcs = ["text_embedder_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,36 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.text.text_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message TextEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional TextEmbedderGraphOptions ext = 477589892;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for configuring the embedder behavior, such as normalization or
// quantization.
optional components.processors.proto.EmbedderOptions embedder_options = 2;
}

View File

@ -0,0 +1,104 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
#include <memory>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
constexpr char kTextTag[] = "TEXT";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTextInStreamName[] = "text_in";
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
constexpr char kGraphTypeName[] =
"mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Creates a MediaPipe graph config that contains a single node of type
// "mediapipe.tasks.text.text_embedder.TextEmbedderGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto) {
api2::builder::Graph graph;
auto& task_graph = graph.AddNode(kGraphTypeName);
task_graph.GetOptions<proto::TextEmbedderGraphOptions>().Swap(
options_proto.get());
graph.In(kTextTag).SetName(kTextInStreamName) >> task_graph.In(kTextTag);
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
graph.Out(kEmbeddingsTag);
return graph.GetConfig();
}
// Converts the user-facing TextEmbedderOptions struct to the internal
// TextEmbedderGraphOptions proto.
std::unique_ptr<proto::TextEmbedderGraphOptions>
ConvertTextEmbedderOptionsToProto(TextEmbedderOptions* options) {
auto options_proto = std::make_unique<proto::TextEmbedderGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
auto embedder_options_proto =
std::make_unique<components::processors::proto::EmbedderOptions>(
components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto;
}
} // namespace
absl::StatusOr<std::unique_ptr<TextEmbedder>> TextEmbedder::Create(
std::unique_ptr<TextEmbedderOptions> options) {
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto =
ConvertTextEmbedderOptionsToProto(options.get());
return core::TaskApiFactory::Create<TextEmbedder,
proto::TextEmbedderGraphOptions>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver));
}
absl::StatusOr<TextEmbedderResult> TextEmbedder::Embed(absl::string_view text) {
ASSIGN_OR_RETURN(
auto output_packets,
runner_->Process(
{{kTextInStreamName, MakePacket<std::string>(std::string(text))}}));
return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
}
absl::StatusOr<double> TextEmbedder::CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v);
}
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -0,0 +1,96 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
namespace mediapipe::tasks::text::text_embedder {
// Alias the shared EmbeddingResult struct as result typo.
using TextEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
// Options for configuring a MediaPipe text embedder task.
struct TextEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// Options for configuring the embedder behavior, such as L2-normalization or
// scalar-quantization.
components::processors::EmbedderOptions embedder_options;
};
// Performs embedding extraction on text.
//
// This API expects a TFLite model with TFLite Model Metadata that contains the
// mandatory (described below) input tensors and output tensors. Metadata should
// contain the input process unit for the model's Tokenizer as well as input /
// output tensor metadata.
//
// TODO: Support Universal Sentence Encoder.
// Input tensors:
// (kTfLiteInt32)
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
// segment ids respectively
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
// input ids
//
// At least one output tensor with:
// (kTfLiteFloat32)
// - `N` components corresponding to the `N` dimensions of the returned
// feature vector for this output layer.
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
class TextEmbedder : core::BaseTaskApi {
public:
using BaseTaskApi::BaseTaskApi;
// Creates a TextEmbedder from the provided `options`. A non-default
// OpResolver can be specified in the BaseOptions in order to support custom
// Ops or specify a subset of built-in Ops.
static absl::StatusOr<std::unique_ptr<TextEmbedder>> Create(
std::unique_ptr<TextEmbedderOptions> options);
// Performs embedding extraction on the input `text`.
absl::StatusOr<TextEmbedderResult> Embed(absl::string_view text);
// Shuts down the TextEmbedder when all the work is done.
absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have a an L2-norm of
// 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v);
};
} // namespace mediapipe::tasks::text::text_embedder
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_

View File

@ -0,0 +1,145 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources;
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
} // namespace
// A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
// extraction.
// - Accepts input text and outputs embeddings on CPU.
//
// Inputs:
// TEXT - std::string
// Input text to perform embedding extraction on.
//
// Outputs:
// EMBEDDINGS - EmbeddingResult
// The embedding result.
//
// Example:
// node {
// calculator: "mediapipe.tasks.text.TextEmbedderGraph"
// input_stream: "TEXT:text_in"
// output_stream: "EMBEDDINGS:embedding_result_out"
// options {
// [mediapipe.tasks.text.text_embedder.proto.TextEmbedderGraphOptions.ext] {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// }
// }
// }
class TextEmbedderGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
CHECK(sc != nullptr);
ASSIGN_OR_RETURN(const ModelResources* model_resources,
CreateModelResources<proto::TextEmbedderGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
Source<EmbeddingResult> embedding_result_out,
BuildTextEmbedderTask(sc->Options<proto::TextEmbedderGraphOptions>(),
*model_resources,
graph[Input<std::string>(kTextTag)], graph));
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
return graph.GetConfig();
}
private:
// Adds a mediapipe TextEmbedder task graph into the provided
// builder::Graph instance. The TextEmbedder task takes an input
// text (std::string) and returns an embedding result.
//
// task_options: the mediapipe tasks TextEmbedderGraphOptions proto.
// model_resources: the ModelResources object initialized from a
// TextEmbedder model file with model metadata.
// text_in: (std::string) stream to run embedding extraction on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<EmbeddingResult>> BuildTextEmbedderTask(
const proto::TextEmbedderGraphOptions& task_options,
const ModelResources& model_resources, Source<std::string> text_in,
Graph& graph) {
// Adds preprocessing calculators and connects them to the text input
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
model_resources,
preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator.
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
// The metadata extractor side-output comes from the
// ModelResourcesCalculator.
inference.SideOut(kMetadataExtractorTag) >>
preprocessing.SideIn(kMetadataExtractorTag);
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
// Adds postprocessing calculators and connects its input stream to the
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding result.
return postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)];
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::text::text_embedder::TextEmbedderGraph);
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -0,0 +1,143 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
#include <memory>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
// Note that these models use dynamic-sized tensors.
// Embedding model with BERT preprocessing.
constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
// Embedding model with regex preprocessing.
constexpr char kRegexOneEmbeddingModel[] =
"regex_one_embedding_with_metadata.tflite";
// Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4;
// Tolerancy for cosine similarity evaluation.
constexpr double kSimilarityTolerancy = 1e-6;
using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
using ::testing::Optional;
class EmbedderTest : public tflite_shims::testing::Test {};
TEST_F(EmbedderTest, FailsWithMissingModel) {
auto text_embedder =
TextEmbedder::Create(std::make_unique<TextEmbedderOptions>());
ASSERT_EQ(text_embedder.status().code(), absl::StatusCode::kInvalidArgument);
ASSERT_THAT(
text_embedder.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
ASSERT_THAT(text_embedder.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(EmbedderTest, SucceedsWithMobileBert) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileBert);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result0.embeddings.size(), 1);
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST(EmbedTest, SucceedsWithRegexOneEmbeddingModel) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kRegexOneEmbeddingModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result0,
text_embedder->Embed("it's a charming and often affecting journey"));
EXPECT_EQ(result0.embeddings.size(), 1);
EXPECT_EQ(result0.embeddings[0].float_embedding.size(), 16);
EXPECT_NEAR(result0.embeddings[0].float_embedding[0], 0.0309356f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
EXPECT_EQ(result1.embeddings.size(), 1);
EXPECT_EQ(result1.embeddings[0].float_embedding.size(), 16);
EXPECT_NEAR(result1.embeddings[0].float_embedding[0], 0.0312863f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.999937, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithQuantization) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileBert);
options->embedder_options.quantize = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result.embeddings.size(), 1);
ASSERT_EQ(result.embeddings[0].quantized_embedding.size(), 512);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -110,4 +110,38 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "hand_landmarker",
srcs = ["hand_landmarker.cc"],
hdrs = ["hand_landmarker.h"],
deps = [
":hand_landmarker_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
)
# TODO: Enable this test # TODO: Enable this test

View File

@ -0,0 +1,269 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
hand_landmarker::proto::HandLandmarkerGraphOptions;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
constexpr char kHandLandmarkerGraphTypeName[] =
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kHandednessStreamName[] = "handedness";
constexpr char kHandLandmarksTag[] = "LANDMARKS";
constexpr char kHandLandmarksStreamName[] = "landmarks";
constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.hand_ladnamrker.HandLandmarkerGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
// limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<HandLandmarkerGraphOptionsProto> options,
bool enable_flow_limiting) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kHandLandmarkerGraphTypeName);
subgraph.GetOptions<HandLandmarkerGraphOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>
graph.Out(kHandednessTag);
subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >>
graph.Out(kHandLandmarksTag);
subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >>
graph.Out(kHandWorldLandmarksTag);
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(
graph, subgraph, {kImageTag, kNormRectTag}, kHandLandmarksTag);
}
graph.In(kImageTag) >> subgraph.In(kImageTag);
graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
return graph.GetConfig();
}
// Converts the user-facing HandLandmarkerOptions struct to the internal
// HandLandmarkerGraphOptions proto.
std::unique_ptr<HandLandmarkerGraphOptionsProto>
ConvertHandLandmarkerGraphOptionsProto(HandLandmarkerOptions* options) {
auto options_proto = std::make_unique<HandLandmarkerGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE);
// Configure hand detector options.
auto* hand_detector_graph_options =
options_proto->mutable_hand_detector_graph_options();
hand_detector_graph_options->set_num_hands(options->num_hands);
hand_detector_graph_options->set_min_detection_confidence(
options->min_hand_detection_confidence);
// Configure hand landmark detector options.
options_proto->set_min_tracking_confidence(options->min_tracking_confidence);
auto* hand_landmarks_detector_graph_options =
options_proto->mutable_hand_landmarks_detector_graph_options();
hand_landmarks_detector_graph_options->set_min_detection_confidence(
options->min_hand_presence_confidence);
return options_proto;
}
} // namespace
absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
std::unique_ptr<HandLandmarkerOptions> options) {
auto options_proto = ConvertHandLandmarkerGraphOptionsProto(options.get());
tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) {
auto result_callback = options->result_callback;
packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
status_or_packets) {
if (!status_or_packets.ok()) {
Image image;
result_callback(status_or_packets.status(), image,
Timestamp::Unset().Value());
return;
}
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return;
}
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
if (status_or_packets.value()[kHandLandmarksStreamName].IsEmpty()) {
Packet empty_packet =
status_or_packets.value()[kHandLandmarksStreamName];
result_callback(
{HandLandmarksDetectionResult()}, image_packet.Get<Image>(),
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
return;
}
Packet handedness_packet =
status_or_packets.value()[kHandednessStreamName];
Packet hand_landmarks_packet =
status_or_packets.value()[kHandLandmarksStreamName];
Packet hand_world_landmarks_packet =
status_or_packets.value()[kHandWorldLandmarksStreamName];
result_callback(
{{handedness_packet.Get<std::vector<ClassificationList>>(),
hand_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()}},
image_packet.Get<Image>(),
hand_landmarks_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
};
}
return core::VisionTaskApiFactory::Create<HandLandmarker,
HandLandmarkerGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback));
}
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarksDetectionResult()};
}
return {{/* handedness= */
{output_packets[kHandednessStreamName]
.Get<std::vector<mediapipe::ClassificationList>>()},
/* hand_landmarks= */
{output_packets[kHandLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
/* hand_world_landmarks */
{output_packets[kHandWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>()}}};
}
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarksDetectionResult()};
}
return {
{/* handedness= */
{output_packets[kHandednessStreamName]
.Get<std::vector<mediapipe::ClassificationList>>()},
/* hand_landmarks= */
{output_packets[kHandLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
/* hand_world_landmarks */
{output_packets[kHandWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>()}},
};
}
absl::Status HandLandmarker::DetectAsync(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,192 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
#include <memory>
#include <optional>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
struct HandLandmarkerOptions {
// Base options for configuring MediaPipe Tasks library, such as specifying
// the TfLite model bundle file with metadata, accelerator options, op
// resolver, etc.
tasks::core::BaseOptions base_options;
// The running mode of the task. Default to the image mode.
// HandLandmarker has three running modes:
// 1) The image mode for detecting hand landmarks on single image inputs.
// 2) The video mode for detecting hand landmarks on the decoded frames of a
// video.
// 3) The live stream mode for detecting hand landmarks on the live stream of
// input data, such as from camera. In this mode, the "result_callback"
// below must be specified to receive the detection results asynchronously.
core::RunningMode running_mode = core::RunningMode::IMAGE;
// The maximum number of hands can be detected by the HandLandmarker.
int num_hands = 1;
// The minimum confidence score for the hand detection to be considered
// successful.
float min_hand_detection_confidence = 0.5;
// The minimum confidence score of hand presence score in the hand landmark
// detection.
float min_hand_presence_confidence = 0.5;
// The minimum confidence score for the hand tracking to be considered
// successful.
float min_tracking_confidence = 0.5;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(
absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
const Image&, int64)>
result_callback = nullptr;
};
// Performs hand landmarks detection on the given image.
//
// TODO add the link to DevSite.
// This API expects a pre-trained hand landmarker model asset bundle.
//
// Inputs:
// Image
// - The image that hand landmarks detection runs on.
// std::optional<NormalizedRect>
// - If provided, can be used to specify the rotation to apply to the image
// before performing hand landmarks detection, by setting its 'rotation'
// field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation).
// Note that specifying a region-of-interest using the 'x_center',
// 'y_center', 'width' and 'height' fields is NOT supported and will
// result in an invalid argument error being returned.
// Outputs:
// HandLandmarksDetectionResult
// - The hand landmarks detection results.
class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
public:
using BaseVisionTaskApi::BaseVisionTaskApi;
// Creates a HandLandmarker from a HandLandmarkerOptions to process image data
// or streaming data. Hand landmarker can be created with one of the following
// three running modes:
// 1) Image mode for detecting hand landmarks on single image inputs. Users
// provide mediapipe::Image to the `Detect` method, and will receive the
// deteced hand landmarks results as the return value.
// 2) Video mode for detecting hand landmarks on the decoded frames of a
// video. Users call `DetectForVideo` method, and will receive the detected
// hand landmarks results as the return value.
// 3) Live stream mode for detecting hand landmarks on the live stream of the
// input data, such as from camera. Users call `DetectAsync` to push the
// image data into the HandLandmarker, the detected results along with the
// input timestamp and the image that hand landmarker runs on will be
// available in the result callback when the hand landmarker finishes the
// work.
static absl::StatusOr<std::unique_ptr<HandLandmarker>> Create(
std::unique_ptr<HandLandmarkerOptions> options);
// Performs hand landmarks detection on the given image.
// Only use this method when the HandLandmarker is created with the image
// running mode.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing detection, by setting
// its 'rotation_degrees' field. Note that specifying a region-of-interest
// using the 'region_of_interest' field is NOT supported and will result in an
// invalid argument error being returned.
//
// The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented.
absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect(
Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs hand landmarks detection on the provided video frame.
// Only use this method when the HandLandmarker is created with the video
// running mode.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing detection, by setting
// its 'rotation_degrees' field. Note that specifying a region-of-interest
// using the 'region_of_interest' field is NOT supported and will result in an
// invalid argument error being returned.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
absl::StatusOr<components::containers::HandLandmarksDetectionResult>
DetectForVideo(Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Sends live image data to perform hand landmarks detection, and the results
// will be available via the "result_callback" provided in the
// HandLandmarkerOptions. Only use this method when the HandLandmarker
// is created with the live stream running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the hand landmarker. The input timestamps must be monotonically
// increasing.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing detection, by setting
// its 'rotation_degrees' field. Note that specifying a region-of-interest
// using the 'region_of_interest' field is NOT supported and will result in an
// invalid argument error being returned.
//
// The "result_callback" provides
// - A vector of HandLandmarksDetectionResult, each is the detected results
// for a input frame.
// - The const reference to the corresponding input image that the hand
// landmarker runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status DetectAsync(Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the HandLandmarker when all works are done.
absl::Status Close() { return runner_->Close(); }
};
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_

View File

@ -0,0 +1,511 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::EqualsProto;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task";
constexpr char kThumbUpLandmarksFilename[] = "thumb_up_landmarks.pbtxt";
constexpr char kPointingUpLandmarksFilename[] = "pointing_up_landmarks.pbtxt";
constexpr char kPointingUpRotatedLandmarksFilename[] =
"pointing_up_rotated_landmarks.pbtxt";
constexpr char kThumbUpImage[] = "thumb_up.jpg";
constexpr char kPointingUpImage[] = "pointing_up.jpg";
constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg";
constexpr char kNoHandsImage[] = "cats_and_dogs.jpg";
constexpr float kLandmarksFractionDiff = 0.03; // percentage
constexpr float kLandmarksAbsMargin = 0.03;
constexpr float kHandednessMargin = 0.05;
LandmarksDetectionResult GetLandmarksDetectionResult(
absl::string_view landmarks_file_name) {
LandmarksDetectionResult result;
MP_EXPECT_OK(GetTextProto(
file::JoinPath("./", kTestDataDirectory, landmarks_file_name), &result,
Defaults()));
// Remove z position of landmarks, because they are not used in correctness
// testing. For video or live stream mode, the z positions varies a lot during
// tracking from frame to frame.
for (int i = 0; i < result.landmarks().landmark().size(); i++) {
auto& landmark = *result.mutable_landmarks()->mutable_landmark(i);
landmark.clear_z();
}
return result;
}
HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
const std::vector<absl::string_view>& landmarks_file_names) {
HandLandmarksDetectionResult expected_results;
for (const auto& file_name : landmarks_file_names) {
const auto landmarks_detection_result =
GetLandmarksDetectionResult(file_name);
expected_results.hand_landmarks.push_back(
landmarks_detection_result.landmarks());
expected_results.handedness.push_back(
landmarks_detection_result.classifications());
}
return expected_results;
}
void ExpectHandLandmarksDetectionResultsCorrect(
const HandLandmarksDetectionResult& actual_results,
const HandLandmarksDetectionResult& expected_results) {
const auto& actual_landmarks = actual_results.hand_landmarks;
const auto& actual_handedness = actual_results.handedness;
const auto& expected_landmarks = expected_results.hand_landmarks;
const auto& expected_handedness = expected_results.handedness;
ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size());
ASSERT_EQ(actual_handedness.size(), expected_handedness.size());
EXPECT_THAT(
actual_handedness,
Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin),
expected_handedness));
EXPECT_THAT(actual_landmarks,
Pointwise(Approximately(Partially(EqualsProto()),
/*margin=*/kLandmarksAbsMargin,
/*fraction=*/kLandmarksFractionDiff),
expected_landmarks));
}
} // namespace
struct TestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of test image.
std::string test_image_name;
// The filename of test model.
std::string test_model_file;
// The rotation to apply to the test image before processing, in degrees
// clockwise.
int rotation;
// Expected results from the hand landmarker model output.
HandLandmarksDetectionResult expected_results;
};
class ImageModeTest : public testing::TestWithParam<TestParams> {};
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
options->running_mode = core::RunningMode::IMAGE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
auto results = hand_landmarker->DetectForVideo(image, 0);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("not initialized with the video mode"));
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
results = hand_landmarker->DetectAsync(image, 0);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("not initialized with the live stream mode"));
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
MP_ASSERT_OK(hand_landmarker->Close());
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
options->running_mode = core::RunningMode::IMAGE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = hand_landmarker->Detect(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("This task doesn't support region-of-interest"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
TEST_P(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
options->running_mode = core::RunningMode::IMAGE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
HandLandmarksDetectionResult hand_landmarker_results;
if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation;
MP_ASSERT_OK_AND_ASSIGN(
hand_landmarker_results,
hand_landmarker->Detect(image, image_processing_options));
} else {
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
hand_landmarker->Detect(image));
}
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
GetParam().expected_results);
MP_ASSERT_OK(hand_landmarker->Close());
}
INSTANTIATE_TEST_SUITE_P(
HandGestureTest, ImageModeTest,
Values(TestParams{
/* test_name= */ "LandmarksThumbUp",
/* test_image_name= */ kThumbUpImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kThumbUpLandmarksFilename}),
},
TestParams{
/* test_name= */ "LandmarksPointingUp",
/* test_image_name= */ kPointingUpImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kPointingUpLandmarksFilename}),
},
TestParams{
/* test_name= */ "LandmarksPointingUpRotated",
/* test_image_name= */ kPointingUpRotatedImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kPointingUpRotatedLandmarksFilename}),
},
TestParams{
/* test_name= */ "NoHands",
/* test_image_name= */ kNoHandsImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
{{}, {}, {}},
}),
[](const TestParamInfo<ImageModeTest::ParamType>& info) {
return info.param.test_name;
});
class VideoModeTest : public testing::TestWithParam<TestParams> {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
options->running_mode = core::RunningMode::VIDEO;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
auto results = hand_landmarker->Detect(image);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("not initialized with the image mode"));
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
results = hand_landmarker->DetectAsync(image, 0);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("not initialized with the live stream mode"));
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
MP_ASSERT_OK(hand_landmarker->Close());
}
TEST_P(VideoModeTest, Succeeds) {
const int iterations = 100;
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
options->running_mode = core::RunningMode::VIDEO;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
const auto expected_results = GetParam().expected_results;
for (int i = 0; i < iterations; ++i) {
HandLandmarksDetectionResult hand_landmarker_results;
if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation;
MP_ASSERT_OK_AND_ASSIGN(
hand_landmarker_results,
hand_landmarker->DetectForVideo(image, i, image_processing_options));
} else {
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
hand_landmarker->DetectForVideo(image, i));
}
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
expected_results);
}
MP_ASSERT_OK(hand_landmarker->Close());
}
INSTANTIATE_TEST_SUITE_P(
HandGestureTest, VideoModeTest,
Values(TestParams{
/* test_name= */ "LandmarksThumbUp",
/* test_image_name= */ kThumbUpImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kThumbUpLandmarksFilename}),
},
TestParams{
/* test_name= */ "LandmarksPointingUp",
/* test_image_name= */ kPointingUpImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kPointingUpLandmarksFilename}),
},
TestParams{
/* test_name= */ "LandmarksPointingUpRotated",
/* test_image_name= */ kPointingUpRotatedImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kPointingUpRotatedLandmarksFilename}),
},
TestParams{
/* test_name= */ "NoHands",
/* test_image_name= */ kNoHandsImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
{{}, {}, {}},
}),
[](const TestParamInfo<ImageModeTest::ParamType>& info) {
return info.param.test_name;
});
class LiveStreamModeTest : public testing::TestWithParam<TestParams> {};
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback =
[](absl::StatusOr<HandLandmarksDetectionResult> results,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
auto results = hand_landmarker->Detect(image);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("not initialized with the image mode"));
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
results = hand_landmarker->DetectForVideo(image, 0);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("not initialized with the video mode"));
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
MP_ASSERT_OK(hand_landmarker->Close());
}
TEST_P(LiveStreamModeTest, Succeeds) {
const int iterations = 100;
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
auto options = std::make_unique<HandLandmarkerOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
options->running_mode = core::RunningMode::LIVE_STREAM;
std::vector<HandLandmarksDetectionResult> hand_landmarker_results;
std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps;
options->result_callback =
[&hand_landmarker_results, &image_sizes, &timestamps](
absl::StatusOr<HandLandmarksDetectionResult> results,
const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(results.status());
hand_landmarker_results.push_back(std::move(results.value()));
image_sizes.push_back({image.width(), image.height()});
timestamps.push_back(timestamp_ms);
};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
for (int i = 0; i < iterations; ++i) {
HandLandmarksDetectionResult hand_landmarker_results;
if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation;
MP_ASSERT_OK(
hand_landmarker->DetectAsync(image, i, image_processing_options));
} else {
MP_ASSERT_OK(hand_landmarker->DetectAsync(image, i));
}
}
MP_ASSERT_OK(hand_landmarker->Close());
// Due to the flow limiter, the total of outputs will be smaller than the
// number of iterations.
ASSERT_LE(hand_landmarker_results.size(), iterations);
ASSERT_GT(hand_landmarker_results.size(), 0);
const auto expected_results = GetParam().expected_results;
for (int i = 0; i < hand_landmarker_results.size(); ++i) {
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i],
expected_results);
}
for (const auto& image_size : image_sizes) {
EXPECT_EQ(image_size.first, image.width());
EXPECT_EQ(image_size.second, image.height());
}
int64 timestamp_ms = -1;
for (const auto& timestamp : timestamps) {
EXPECT_GT(timestamp, timestamp_ms);
timestamp_ms = timestamp;
}
}
INSTANTIATE_TEST_SUITE_P(
HandGestureTest, LiveStreamModeTest,
Values(TestParams{
/* test_name= */ "LandmarksThumbUp",
/* test_image_name= */ kThumbUpImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kThumbUpLandmarksFilename}),
},
TestParams{
/* test_name= */ "LandmarksPointingUp",
/* test_image_name= */ kPointingUpImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kPointingUpLandmarksFilename}),
},
TestParams{
/* test_name= */ "LandmarksPointingUpRotated",
/* test_image_name= */ kPointingUpRotatedImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90,
/* expected_results = */
GetExpectedHandLandmarksDetectionResult(
{kPointingUpRotatedLandmarksFilename}),
},
TestParams{
/* test_name= */ "NoHands",
/* test_image_name= */ kNoHandsImage,
/* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0,
/* expected_results = */
{{}, {}, {}},
}),
[](const TestParamInfo<ImageModeTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -32,7 +32,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
@ -63,7 +63,7 @@ cc_library(
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util", "//mediapipe/util:label_map_util",

View File

@ -23,10 +23,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_segmenter {
namespace { namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out"; constexpr char kSegmentationStreamName[] = "segmented_mask_out";
@ -37,23 +39,24 @@ constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageSegmenterGraph"; "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::components::proto::SegmenterOptions;
using ImageSegmenterOptionsProto = using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterOptions; image_segmenter::proto::ImageSegmenterGraphOptions;
// Creates a MediaPipe graph config that only contains a single subgraph node of // Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.ImageSegmenterGraph". // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterOptionsProto> options, std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool enable_flow_limiting) { bool enable_flow_limiting) {
api2::builder::Graph graph; api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get()); task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get());
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
@ -72,9 +75,9 @@ CalculatorGraphConfig CreateGraphConfig(
// Converts the user-facing ImageSegmenterOptions struct to the internal // Converts the user-facing ImageSegmenterOptions struct to the internal
// ImageSegmenterOptions proto. // ImageSegmenterOptions proto.
std::unique_ptr<ImageSegmenterOptionsProto> ConvertImageSegmenterOptionsToProto( std::unique_ptr<ImageSegmenterGraphOptionsProto>
ImageSegmenterOptions* options) { ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
auto options_proto = std::make_unique<ImageSegmenterOptionsProto>(); auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->Swap(base_options_proto.get());
@ -137,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
}; };
} }
return core::VisionTaskApiFactory::Create<ImageSegmenter, return core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterOptionsProto>( ImageSegmenterGraphOptionsProto>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM),
@ -211,6 +214,7 @@ absl::Status ImageSegmenter::SegmentAsync(
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }
} // namespace image_segmenter
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -26,12 +26,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_segmenter {
// The options for configuring a mediapipe image segmenter task. // The options for configuring a mediapipe image segmenter task.
struct ImageSegmenterOptions { struct ImageSegmenterOptions {
@ -191,6 +191,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
}; };
} // namespace image_segmenter
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h" #include "mediapipe/util/label_map_util.h"
@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_segmenter {
namespace { namespace {
@ -55,7 +56,8 @@ using ::mediapipe::api2::builder::MultiSource;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::components::proto::SegmenterOptions;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::
ImageSegmenterGraphOptions;
using ::tflite::Tensor; using ::tflite::Tensor;
using ::tflite::TensorMetadata; using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
@ -77,7 +79,7 @@ struct ImageSegmenterOutputs {
} // namespace } // namespace
absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) { absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
if (options.segmenter_options().output_type() == if (options.segmenter_options().output_type() ==
SegmenterOptions::UNSPECIFIED) { SegmenterOptions::UNSPECIFIED) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
@ -112,7 +114,7 @@ absl::StatusOr<LabelItems> GetLabelItemsIfAny(
} }
absl::Status ConfigureTensorsToSegmentationCalculator( absl::Status ConfigureTensorsToSegmentationCalculator(
const ImageSegmenterOptions& segmenter_option, const ImageSegmenterGraphOptions& segmenter_option,
const core::ModelResources& model_resources, const core::ModelResources& model_resources,
TensorsToSegmentationCalculatorOptions* options) { TensorsToSegmentationCalculatorOptions* options) {
*options->mutable_segmenter_options() = segmenter_option.segmenter_options(); *options->mutable_segmenter_options() = segmenter_option.segmenter_options();
@ -181,7 +183,7 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
// input_stream: "IMAGE:image" // input_stream: "IMAGE:image"
// output_stream: "SEGMENTATION:segmented_masks" // output_stream: "SEGMENTATION:segmented_masks"
// options { // options {
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext] // [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
// { // {
// base_options { // base_options {
// model_asset { // model_asset {
@ -200,12 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig( absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override { mediapipe::SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterOptions>(sc)); CreateModelResources<ImageSegmenterGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto output_streams,
BuildSegmentationTask( BuildSegmentationTask(
sc->Options<ImageSegmenterOptions>(), *model_resources, sc->Options<ImageSegmenterGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
@ -228,13 +230,13 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
// builder::Graph instance. The segmentation pipeline takes images // builder::Graph instance. The segmentation pipeline takes images
// (mediapipe::Image) as the input and returns segmented image mask as output. // (mediapipe::Image) as the input and returns segmented image mask as output.
// //
// task_options: the mediapipe tasks ImageSegmenterOptions proto. // task_options: the mediapipe tasks ImageSegmenterGraphOptions proto.
// model_resources: the ModelSources object initialized from a segmentation // model_resources: the ModelSources object initialized from a segmentation
// model file with model metadata. // model file with model metadata.
// image_in: (mediapipe::Image) stream to run segmentation on. // image_in: (mediapipe::Image) stream to run segmentation on.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterOptions& task_options, const ImageSegmenterGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
@ -293,8 +295,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
} }
}; };
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageSegmenterGraph); REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_segmenter::ImageSegmenterGraph);
} // namespace image_segmenter
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h"
@ -42,6 +42,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_segmenter {
namespace { namespace {
using ::mediapipe::Image; using ::mediapipe::Image;
@ -547,6 +548,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
// TODO: Add test for hair segmentation model. // TODO: Add test for hair segmentation model.
} // namespace } // namespace
} // namespace image_segmenter
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library( mediapipe_proto_library(
name = "image_segmenter_options_proto", name = "image_segmenter_graph_options_proto",
srcs = ["image_segmenter_options.proto"], srcs = ["image_segmenter_graph_options.proto"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",

View File

@ -21,9 +21,12 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/components/proto/segmenter_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageSegmenterOptions { option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto";
option java_outer_classname = "ImageSegmenterGraphOptionsProto";
message ImageSegmenterGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ImageSegmenterOptions ext = 458105758; optional ImageSegmenterGraphOptions ext = 458105758;
} }
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc. // model file with metadata, accelerator options, etc.

View File

@ -20,6 +20,7 @@ android_library(
name = "category", name = "category",
srcs = ["Category.java"], srcs = ["Category.java"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
@ -36,20 +37,29 @@ android_library(
) )
android_library( android_library(
name = "classification_entry", name = "classifications",
srcs = ["ClassificationEntry.java"], srcs = ["Classifications.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
":category", ":category",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
android_library( android_library(
name = "classifications", name = "classificationresult",
srcs = ["Classifications.java"], srcs = ["ClassificationResult.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
":classification_entry", ":classifications",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],

View File

@ -15,6 +15,7 @@
package com.google.mediapipe.tasks.components.containers; package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto;
import java.util.Objects; import java.util.Objects;
/** /**
@ -38,6 +39,16 @@ public abstract class Category {
return new AutoValue_Category(score, index, categoryName, displayName); return new AutoValue_Category(score, index, categoryName, displayName);
} }
/**
* Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf
* message.
*
* @param proto the {@link ClassificationProto.Classification} protobuf message to convert.
*/
public static Category createFromProto(ClassificationProto.Classification proto) {
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
}
/** The probability score of this label category. */ /** The probability score of this label category. */
public abstract float score(); public abstract float score();

View File

@ -1,48 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents a list of predicted categories with an optional timestamp. Typically used as result
* for classification tasks.
*/
@AutoValue
public abstract class ClassificationEntry {
/**
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
* timestamp.
*
* @param categories the list of {@link Category} objects that contain category name, display
* name, score and label index.
* @param timestampMs the {@link long} representing the timestamp for which these categories were
* obtained.
*/
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
}
/** The list of predicted {@link Category} objects, sorted by descending score. */
public abstract List<Category> categories();
/**
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
* series use cases, e.g. audio classification.
*/
public abstract long timestampMs();
}

View File

@ -0,0 +1,76 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Represents the classification results of a model. Typically used as a result for classification
* tasks.
*/
@AutoValue
public abstract class ClassificationResult {
/**
* Creates a {@link ClassificationResult} instance.
*
* @param classifications the list of {@link Classifications} objects containing the predicted
* categories for each head of the model.
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*/
public static ClassificationResult create(
List<Classifications> classifications, Optional<Long> timestampMs) {
return new AutoValue_ClassificationResult(
Collections.unmodifiableList(classifications), timestampMs);
}
/**
* Creates a {@link ClassificationResult} object from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
*/
public static ClassificationResult createFromProto(
ClassificationsProto.ClassificationResult proto) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
proto.getClassificationsList()) {
classifications.add(Classifications.createFromProto(classificationsProto));
}
Optional<Long> timestampMs =
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
return create(classifications, timestampMs);
}
/** The classification results for each head of the model. */
public abstract List<Classifications> classifications();
/**
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
* these results.
*
* <p>This is only used for classification on time series (e.g. audio classification). In these
* use cases, the amount of data to process might exceed the maximum size that the model can
* process: to solve this, the input data is split into multiple chunks starting at different
* timestamps.
*/
public abstract Optional<Long> timestampMs();
}

View File

@ -15,8 +15,12 @@
package com.google.mediapipe.tasks.components.containers; package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
/** /**
* Represents the list of classification for a given classifier head. Typically used as a result for * Represents the list of classification for a given classifier head. Typically used as a result for
@ -28,25 +32,41 @@ public abstract class Classifications {
/** /**
* Creates a {@link Classifications} instance. * Creates a {@link Classifications} instance.
* *
* @param entries the list of {@link ClassificationEntry} objects containing the predicted * @param categories the list of {@link Category} objects containing the predicted categories.
* categories.
* @param headIndex the index of the classifier head. * @param headIndex the index of the classifier head.
* @param headName the name of the classifier head. * @param headName the optional name of the classifier head.
*/ */
public static Classifications create( public static Classifications create(
List<ClassificationEntry> entries, int headIndex, String headName) { List<Category> categories, int headIndex, Optional<String> headName) {
return new AutoValue_Classifications( return new AutoValue_Classifications(
Collections.unmodifiableList(entries), headIndex, headName); Collections.unmodifiableList(categories), headIndex, headName);
} }
/** A list of {@link ClassificationEntry} objects. */ /**
public abstract List<ClassificationEntry> entries(); * Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications}
* protobuf message.
*
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
*/
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
List<Category> categories = new ArrayList<>();
for (ClassificationProto.Classification classificationProto :
proto.getClassificationList().getClassificationList()) {
categories.add(Category.createFromProto(classificationProto));
}
Optional<String> headName =
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
return create(categories, proto.getHeadIndex(), headName);
}
/** A list of {@link Category} objects. */
public abstract List<Category> categories();
/** /**
* The index of the classifier head these entries refer to. This is useful for multi-head models. * The index of the classifier head these entries refer to. This is useful for multi-head models.
*/ */
public abstract int headIndex(); public abstract int headIndex();
/** The name of the classifier head, which is the corresponding tensor metadata name. */ /** The optional name of the classifier head, which is the corresponding tensor metadata name. */
public abstract String headName(); public abstract Optional<String> headName();
} }

View File

@ -26,22 +26,23 @@ public abstract class BaseOptions {
@AutoValue.Builder @AutoValue.Builder
public abstract static class Builder { public abstract static class Builder {
/** /**
* Sets the model path to a tflite model with metadata in the assets. * Sets the model path to a model asset file (a tflite model or a model asset bundle file) in
* the Android app assets folder.
* *
* <p>Note: when model path is set, both model file descriptor and model buffer should be empty. * <p>Note: when model path is set, both model file descriptor and model buffer should be empty.
*/ */
public abstract Builder setModelAssetPath(String value); public abstract Builder setModelAssetPath(String value);
/** /**
* Sets the native fd int of a tflite model with metadata. * Sets the native fd int of a model asset file (a tflite model or a model asset bundle file).
* *
* <p>Note: when model file descriptor is set, both model path and model buffer should be empty. * <p>Note: when model file descriptor is set, both model path and model buffer should be empty.
*/ */
public abstract Builder setModelAssetFileDescriptor(Integer value); public abstract Builder setModelAssetFileDescriptor(Integer value);
/** /**
* Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model * Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a model asset
* with metadata. * file (a tflite model or a model asset bundle file).
* *
* <p>Note: when model buffer is set, both model file and model file descriptor should be empty. * <p>Note: when model buffer is set, both model file and model file descriptor should be empty.
*/ */

View File

@ -22,6 +22,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
@ -36,6 +37,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
@ -232,7 +234,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
"//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",

View File

@ -37,8 +37,8 @@ cc_library(
android_library( android_library(
name = "textclassifier", name = "textclassifier",
srcs = [ srcs = [
"textclassifier/TextClassificationResult.java",
"textclassifier/TextClassifier.java", "textclassifier/TextClassifier.java",
"textclassifier/TextClassifierResult.java",
], ],
javacopts = [ javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",
@ -51,9 +51,7 @@ android_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",

View File

@ -1,103 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the classification results generated by {@link TextClassifier}. */
@AutoValue
public abstract class TextClassificationResult implements TaskResult {
/**
* Creates an {@link TextClassificationResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
* message.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static TextClassificationResult create(
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
classificationResult.getClassificationsList()) {
classifications.add(classificationsFromProto(classificationsProto));
}
return new AutoValue_TextClassificationResult(
timestampMs, Collections.unmodifiableList(classifications));
}
@Override
public abstract long timestampMs();
/** Contains one set of results per classifier head. */
@SuppressWarnings("AutoValueImmutableFields")
public abstract List<Classifications> classifications();
/**
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
*
* @param category the {@link CategoryProto.Category} protobuf message to convert.
*/
static Category categoryFromProto(CategoryProto.Category category) {
return Category.create(
category.getScore(),
category.getIndex(),
category.getCategoryName(),
category.getDisplayName());
}
/**
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
* ClassificationEntry} object.
*
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
*/
static ClassificationEntry classificationEntryFromProto(
ClassificationsProto.ClassificationEntry entry) {
List<Category> categories = new ArrayList<>();
for (CategoryProto.Category category : entry.getCategoriesList()) {
categories.add(categoryFromProto(category));
}
return ClassificationEntry.create(categories, entry.getTimestampMs());
}
/**
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
* Classifications} object.
*
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
* convert.
*/
static Classifications classificationsFromProto(
ClassificationsProto.Classifications classifications) {
List<ClassificationEntry> entries = new ArrayList<>();
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
entries.add(classificationEntryFromProto(entry));
}
return Classifications.create(
entries, classifications.getHeadIndex(), classifications.getHeadName());
}
}

View File

@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable {
@SuppressWarnings("ConstantCaseForConstants") @SuppressWarnings("ConstantCaseForConstants")
private static final List<String> OUTPUT_STREAMS = private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out"));
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
private final TaskRunner runner; private final TaskRunner runner;
@ -142,17 +142,18 @@ public final class TextClassifier implements AutoCloseable {
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation. * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
*/ */
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) { public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>(); OutputHandler<TextClassifierResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() { new OutputHandler.OutputPacketConverter<TextClassifierResult, Void>() {
@Override @Override
public TextClassificationResult convertToTaskResult(List<Packet> packets) { public TextClassifierResult convertToTaskResult(List<Packet> packets) {
try { try {
return TextClassificationResult.create( return TextClassifierResult.create(
PacketGetter.getProto( ClassificationResult.createFromProto(
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), PacketGetter.getProto(
ClassificationsProto.ClassificationResult.getDefaultInstance()), packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); ClassificationsProto.ClassificationResult.getDefaultInstance())),
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
} catch (IOException e) { } catch (IOException e) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable {
* *
* @param inputText a {@link String} for processing. * @param inputText a {@link String} for processing.
*/ */
public TextClassificationResult classify(String inputText) { public TextClassifierResult classify(String inputText) {
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
return (TextClassificationResult) runner.process(inputPackets); return (TextClassifierResult) runner.process(inputPackets);
} }
/** Closes and cleans up the {@link TextClassifier}. */ /** Closes and cleans up the {@link TextClassifier}. */

View File

@ -0,0 +1,55 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
/** Represents the classification results generated by {@link TextClassifier}. */
@AutoValue
public abstract class TextClassifierResult implements TaskResult {
/**
* Creates an {@link TextClassifierResult} instance.
*
* @param classificationResult the {@link ClassificationResult} object containing one set of
* results per classifier head.
* @param timestampMs a timestamp for this result.
*/
static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
return new AutoValue_TextClassifierResult(classificationResult, timestampMs);
}
/**
* Creates an {@link TextClassifierResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static TextClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return create(ClassificationResult.createFromProto(proto), timestampMs);
}
/** Contains one set of results per classifier head. */
public abstract ClassificationResult classificationResult();
@Override
public abstract long timestampMs();
}

View File

@ -84,8 +84,8 @@ android_library(
android_library( android_library(
name = "imageclassifier", name = "imageclassifier",
srcs = [ srcs = [
"imageclassifier/ImageClassificationResult.java",
"imageclassifier/ImageClassifier.java", "imageclassifier/ImageClassifier.java",
"imageclassifier/ImageClassifierResult.java",
], ],
javacopts = [ javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",
@ -100,9 +100,7 @@ android_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",

View File

@ -1,102 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the classification results generated by {@link ImageClassifier}. */
@AutoValue
public abstract class ImageClassificationResult implements TaskResult {
/**
* Creates an {@link ImageClassificationResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
* message.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static ImageClassificationResult create(
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
classificationResult.getClassificationsList()) {
classifications.add(classificationsFromProto(classificationsProto));
}
return new AutoValue_ImageClassificationResult(
timestampMs, Collections.unmodifiableList(classifications));
}
@Override
public abstract long timestampMs();
/** Contains one set of results per classifier head. */
public abstract List<Classifications> classifications();
/**
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
*
* @param category the {@link CategoryProto.Category} protobuf message to convert.
*/
static Category categoryFromProto(CategoryProto.Category category) {
return Category.create(
category.getScore(),
category.getIndex(),
category.getCategoryName(),
category.getDisplayName());
}
/**
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
* ClassificationEntry} object.
*
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
*/
static ClassificationEntry classificationEntryFromProto(
ClassificationsProto.ClassificationEntry entry) {
List<Category> categories = new ArrayList<>();
for (CategoryProto.Category category : entry.getCategoriesList()) {
categories.add(categoryFromProto(category));
}
return ClassificationEntry.create(categories, entry.getTimestampMs());
}
/**
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
* Classifications} object.
*
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
* convert.
*/
static Classifications classificationsFromProto(
ClassificationsProto.Classifications classifications) {
List<ClassificationEntry> entries = new ArrayList<>();
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
entries.add(classificationEntryFromProto(entry));
}
return Classifications.create(
entries, classifications.getHeadIndex(), classifications.getHeadName());
}
}

View File

@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS = private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out")); Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out"));
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
*/ */
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
OutputHandler<ImageClassificationResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageClassifierResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ImageClassificationResult, MPImage>() { new OutputHandler.OutputPacketConverter<ImageClassifierResult, MPImage>() {
@Override @Override
public ImageClassificationResult convertToTaskResult(List<Packet> packets) { public ImageClassifierResult convertToTaskResult(List<Packet> packets) {
try { try {
return ImageClassificationResult.create( return ImageClassifierResult.create(
PacketGetter.getProto( ClassificationResult.createFromProto(
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), PacketGetter.getProto(
ClassificationsProto.ClassificationResult.getDefaultInstance()), packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); ClassificationsProto.ClassificationResult.getDefaultInstance())),
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
} catch (IOException e) { } catch (IOException e) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classify(MPImage image) { public ImageClassifierResult classify(MPImage image) {
return classify(image, ImageProcessingOptions.builder().build()); return classify(image, ImageProcessingOptions.builder().build());
} }
@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* input image before running inference. * input image before running inference.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classify( public ImageClassifierResult classify(
MPImage image, ImageProcessingOptions imageProcessingOptions) { MPImage image, ImageProcessingOptions imageProcessingOptions) {
return (ImageClassificationResult) processImageData(image, imageProcessingOptions); return (ImageClassifierResult) processImageData(image, imageProcessingOptions);
} }
/** /**
@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) {
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
} }
@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classifyForVideo( public ImageClassifierResult classifyForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs);
} }
/** /**
@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* the image classifier is in the live stream mode. * the image classifier is in the live stream mode.
*/ */
public abstract Builder setResultListener( public abstract Builder setResultListener(
ResultListener<ImageClassificationResult, MPImage> resultListener); ResultListener<ImageClassifierResult, MPImage> resultListener);
/** Sets an optional {@link ErrorListener}. */ /** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener errorListener); public abstract Builder setErrorListener(ErrorListener errorListener);
@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
abstract Optional<ClassifierOptions> classifierOptions(); abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<ResultListener<ImageClassificationResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();

View File

@ -0,0 +1,55 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
/** Represents the classification results generated by {@link ImageClassifier}. */
@AutoValue
public abstract class ImageClassifierResult implements TaskResult {
/**
* Creates an {@link ImageClassifierResult} instance.
*
* @param classificationResult the {@link ClassificationResult} object containing one set of
* results per classifier head.
* @param timestampMs a timestamp for this result.
*/
static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
return new AutoValue_ImageClassifierResult(classificationResult, timestampMs);
}
/**
* Creates an {@link ImageClassifierResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static ImageClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return create(ClassificationResult.createFromProto(proto), timestampMs);
}
/** Contains one set of results per classifier head. */
public abstract ClassificationResult classificationResult();
@Override
public abstract long timestampMs();
}

View File

@ -76,7 +76,7 @@ public class TextClassifierTest {
public void classify_succeedsWithBert() throws Exception { public void classify_succeedsWithBert() throws Exception {
TextClassifier textClassifier = TextClassifier textClassifier =
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults); assertHasOneHead(negativeResults);
assertCategoriesAre( assertCategoriesAre(
negativeResults, negativeResults,
@ -84,7 +84,7 @@ public class TextClassifierTest {
Category.create(0.95630914f, 0, "negative", ""), Category.create(0.95630914f, 0, "negative", ""),
Category.create(0.04369091f, 1, "positive", ""))); Category.create(0.04369091f, 1, "positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertCategoriesAre( assertCategoriesAre(
positiveResults, positiveResults,
@ -99,7 +99,7 @@ public class TextClassifierTest {
TextClassifier.createFromFile( TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE)); TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults); assertHasOneHead(negativeResults);
assertCategoriesAre( assertCategoriesAre(
negativeResults, negativeResults,
@ -107,7 +107,7 @@ public class TextClassifierTest {
Category.create(0.95630914f, 0, "negative", ""), Category.create(0.95630914f, 0, "negative", ""),
Category.create(0.04369091f, 1, "positive", ""))); Category.create(0.04369091f, 1, "positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertCategoriesAre( assertCategoriesAre(
@ -122,7 +122,7 @@ public class TextClassifierTest {
TextClassifier textClassifier = TextClassifier textClassifier =
TextClassifier.createFromFile( TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults); assertHasOneHead(negativeResults);
assertCategoriesAre( assertCategoriesAre(
negativeResults, negativeResults,
@ -130,7 +130,7 @@ public class TextClassifierTest {
Category.create(0.6647746f, 0, "Negative", ""), Category.create(0.6647746f, 0, "Negative", ""),
Category.create(0.33522537f, 1, "Positive", ""))); Category.create(0.33522537f, 1, "Positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertCategoriesAre( assertCategoriesAre(
positiveResults, positiveResults,
@ -139,16 +139,15 @@ public class TextClassifierTest {
Category.create(0.48799595f, 1, "Positive", ""))); Category.create(0.48799595f, 1, "Positive", "")));
} }
private static void assertHasOneHead(TextClassificationResult results) { private static void assertHasOneHead(TextClassifierResult results) {
assertThat(results.classifications()).hasSize(1); assertThat(results.classificationResult().classifications()).hasSize(1);
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); assertThat(results.classificationResult().classifications().get(0).headName().get())
assertThat(results.classifications().get(0).entries()).hasSize(1); .isEqualTo("probability");
} }
private static void assertCategoriesAre( private static void assertCategoriesAre(TextClassifierResult results, List<Category> categories) {
TextClassificationResult results, List<Category> categories) { assertThat(results.classificationResult().classifications().get(0).categories())
assertThat(results.classifications().get(0).entries().get(0).categories())
.isEqualTo(categories); .isEqualTo(categories);
} }
} }

View File

@ -91,11 +91,12 @@ public class ImageClassifierTest {
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromFile( ImageClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001); assertThat(results.classificationResult().classifications().get(0).categories())
assertThat(results.classifications().get(0).entries().get(0).categories().get(0)) .hasSize(1001);
assertThat(results.classificationResult().classifications().get(0).categories().get(0))
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", "")); .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
} }
@ -108,9 +109,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -128,9 +129,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
} }
@ -144,9 +145,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -166,9 +167,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -190,9 +191,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -214,10 +215,10 @@ public class ImageClassifierTest {
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
ImageClassificationResult results = ImageClassifierResult results =
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
} }
@ -233,10 +234,10 @@ public class ImageClassifierTest {
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ImageClassificationResult results = ImageClassifierResult results =
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -258,11 +259,11 @@ public class ImageClassifierTest {
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
ImageClassificationResult results = ImageClassifierResult results =
imageClassifier.classify( imageClassifier.classify(
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
} }
@ -391,9 +392,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
} }
@ -410,9 +411,8 @@ public class ImageClassifierTest {
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
ImageClassificationResult results = ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); assertHasOneHead(results);
assertHasOneHeadAndOneTimestamp(results, i);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
} }
@ -478,24 +478,17 @@ public class ImageClassifierTest {
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
} }
private static void assertHasOneHeadAndOneTimestamp( private static void assertHasOneHead(ImageClassifierResult results) {
ImageClassificationResult results, long timestampMs) { assertThat(results.classificationResult().classifications()).hasSize(1);
assertThat(results.classifications()).hasSize(1); assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); assertThat(results.classificationResult().classifications().get(0).headName().get())
assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); .isEqualTo("probability");
assertThat(results.classifications().get(0).entries()).hasSize(1);
assertThat(results.classifications().get(0).entries().get(0).timestampMs())
.isEqualTo(timestampMs);
} }
private static void assertCategoriesAre( private static void assertCategoriesAre(
ImageClassificationResult results, List<Category> categories) { ImageClassifierResult results, List<Category> categories) {
assertThat(results.classifications().get(0).entries().get(0).categories()) assertThat(results.classificationResult().classifications().get(0).categories())
.hasSize(categories.size()); .isEqualTo(categories);
for (int i = 0; i < categories.size(); i++) {
assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
.isEqualTo(categories.get(i));
}
} }
private static void assertImageSizeIsExpected(MPImage inputImage) { private static void assertImageSizeIsExpected(MPImage inputImage) {

View File

@ -43,3 +43,9 @@ py_library(
srcs = ["image_classifier.py"], srcs = ["image_classifier.py"],
deps = [":metadata_writer"], deps = [":metadata_writer"],
) )
py_library(
name = "text_classifier",
srcs = ["text_classifier.py"],
deps = [":metadata_writer"],
)

View File

@ -62,10 +62,10 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
Returns: Returns:
An MetadataWrite object. A MetadataWriter object.
""" """
writer = metadata_writer.MetadataWriter(model_buffer) writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION) writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_image_input(input_norm_mean, input_norm_std) writer.add_image_input(input_norm_mean, input_norm_std)
writer.add_classification_output(labels, score_calibration) writer.add_classification_output(labels, score_calibration)
return cls(writer) return cls(writer)

View File

@ -228,6 +228,45 @@ class ScoreThresholdingMd:
return score_thresholding return score_thresholding
class RegexTokenizerMd:
"""A container for the Regex tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
"""
def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
"""Initializes a RegexTokenizerMd object.
Args:
delim_regex_pattern: the regular expression to segment strings and create
tokens.
vocab_file_path: path to the vocabulary file.
"""
self._delim_regex_pattern = delim_regex_pattern
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the Regex tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the Regex tokenizer metadata.
"""
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = _VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
# Create the RegexTokenizer.
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = (
_metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
tokenizer.options = _metadata_fb.RegexTokenizerOptionsT()
tokenizer.options.delimRegexPattern = self._delim_regex_pattern
tokenizer.options.vocabFile = [vocab]
return tokenizer
class TensorMd: class TensorMd:
"""A container for common tensor metadata information. """A container for common tensor metadata information.
@ -397,6 +436,56 @@ class InputImageTensorMd(TensorMd):
return tensor_metadata return tensor_metadata
class InputTextTensorMd(TensorMd):
"""A container for the input text tensor metadata information.
Attributes:
tokenizer_md: information of the tokenizer in the input text tensor, if any.
"""
def __init__(self,
name: Optional[str] = None,
description: Optional[str] = None,
tokenizer_md: Optional[RegexTokenizerMd] = None):
"""Initializes the instance of InputTextTensorMd.
Args:
name: name of the tensor.
description: description of what the tensor is.
tokenizer_md: information of the tokenizer in the input text tensor, if
any. Only `RegexTokenizer` [1] is currenly supported. If the tokenizer
is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to
`BertInputTensorsMd` class.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
super().__init__(name, description)
self.tokenizer_md = tokenizer_md
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
"""Creates the input text metadata based on the information.
Returns:
A Flatbuffers Python object of the input text metadata.
Raises:
ValueError: if the type of tokenizer_md is unsupported.
"""
if not isinstance(self.tokenizer_md, (type(None), RegexTokenizerMd)):
raise ValueError(
f"The type of tokenizer_options, {type(self.tokenizer_md)}, is "
f"unsupported")
tensor_metadata = super().create_metadata()
if self.tokenizer_md:
tensor_metadata.processUnits = [self.tokenizer_md.create_metadata()]
return tensor_metadata
class ClassificationTensorMd(TensorMd): class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information. """A container for the classification tensor metadata information.

View File

@ -29,6 +29,9 @@ from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
_INPUT_IMAGE_NAME = 'image' _INPUT_IMAGE_NAME = 'image'
_INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.' _INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
_INPUT_REGEX_TEXT_NAME = 'input_text'
_INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
'text to be processed.')
_OUTPUT_CLASSIFICATION_NAME = 'score' _OUTPUT_CLASSIFICATION_NAME = 'score'
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.' _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
@ -82,6 +85,22 @@ class ScoreThresholding:
global_score_threshold: float global_score_threshold: float
@dataclasses.dataclass
class RegexTokenizer:
"""Parameters of the Regex tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
Attributes:
delim_regex_pattern: the regular expression to segment strings and create
tokens.
vocab_file_path: path to the vocabulary file.
"""
delim_regex_pattern: str
vocab_file_path: str
class Labels(object): class Labels(object):
"""Simple container holding classification labels of a particular tensor. """Simple container holding classification labels of a particular tensor.
@ -355,11 +374,11 @@ class MetadataWriter(object):
if os.path.exists(self._temp_folder.name): if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup() self._temp_folder.cleanup()
def add_genernal_info( def add_general_info(
self, self,
model_name: str, model_name: str,
model_description: Optional[str] = None) -> 'MetadataWriter': model_description: Optional[str] = None) -> 'MetadataWriter':
"""Adds a genernal info metadata for the general metadata informantion.""" """Adds a general info metadata for the general metadata informantion."""
# Will overwrite the previous `self._general_md` if exists. # Will overwrite the previous `self._general_md` if exists.
self._general_md = metadata_info.GeneralMd( self._general_md = metadata_info.GeneralMd(
name=model_name, description=model_description) name=model_name, description=model_description)
@ -415,6 +434,34 @@ class MetadataWriter(object):
self._input_mds.append(input_md) self._input_mds.append(input_md)
return self return self
def add_regex_text_input(
self,
regex_tokenizer: RegexTokenizer,
name: str = _INPUT_REGEX_TEXT_NAME,
description: str = _INPUT_REGEX_TEXT_DESCRIPTION) -> 'MetadataWriter':
"""Adds an input text metadata for the text input with regex tokenizer.
Args:
regex_tokenizer: information of the regex tokenizer [1] used to process
the input string.
name: Name of the input tensor.
description: Description of the input tensor.
Returns:
The MetaWriter instance, can be used for chained operation.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
"""
tokenizer_md = metadata_info.RegexTokenizerMd(
delim_regex_pattern=regex_tokenizer.delim_regex_pattern,
vocab_file_path=regex_tokenizer.vocab_file_path)
input_md = metadata_info.InputTextTensorMd(
name=name, description=description, tokenizer_md=tokenizer_md)
self._input_mds.append(input_md)
self._associated_files.append(regex_tokenizer.vocab_file_path)
return self
def add_classification_output( def add_classification_output(
self, self,
labels: Optional[Labels] = None, labels: Optional[Labels] = None,

View File

@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and label file to the Text classifier models."""
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
_MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
class MetadataWriter(metadata_writer.MetadataWriterBase):
"""MetadataWriter to write the metadata into the text classifier."""
@classmethod
def create_for_regex_model(
cls, model_buffer: bytearray,
regex_tokenizer: metadata_writer.RegexTokenizer,
labels: metadata_writer.Labels) -> "MetadataWriter":
"""Creates MetadataWriter for TFLite model with regex tokentizer.
The parameters required in this method are mandatory when using MediaPipe
Tasks.
Note that only the output TFLite is used for deployment. The output JSON
content is used to interpret the metadata content.
Args:
model_buffer: A valid flatbuffer loaded from the TFLite model file.
regex_tokenizer: information of the regex tokenizer [1] used to process
the input string. If the tokenizer is `BertTokenizer` [2] or
`SentencePieceTokenizer` [3], please refer to
`create_for_bert_model`.
labels: an instance of Labels helper class used in the output
classification tensor [4].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
Returns:
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_regex_text_input(regex_tokenizer)
writer.add_classification_output(labels)
return cls(writer)

View File

@ -174,12 +174,10 @@ class AudioClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _AudioClassifier) self.assertIsInstance(classifier, _AudioClassifier)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _AudioClassifierOptions(base_options=base_options) options = _AudioClassifierOptions(base_options=base_options)
_AudioClassifier.create_from_options(options) _AudioClassifier.create_from_options(options)

View File

@ -53,3 +53,17 @@ py_test(
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
py_test(
name = "text_classifier_test",
srcs = ["text_classifier_test.py"],
data = [
"//mediapipe/tasks/testdata/metadata:data_files",
"//mediapipe/tasks/testdata/metadata:model_files",
],
deps = [
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -191,6 +191,43 @@ class InputImageTensorMdTest(parameterized.TestCase):
f"{len(norm_mean)} and {len(norm_std)}", str(error.exception)) f"{len(norm_mean)} and {len(norm_std)}", str(error.exception))
class InputTextTensorMdTest(absltest.TestCase):
_NAME = "input text"
_DESCRIPTION = "The input string."
_VOCAB_FILE = "vocab.txt"
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "input_text_tensor_meta.json"))
_EXPECTED_TENSOR_DEFAULT_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "input_text_tensor_default_meta.json"))
def test_create_metadata_should_succeed(self):
regex_tokenizer_md = metadata_info.RegexTokenizerMd(
self._DELIM_REGEX_PATTERN, self._VOCAB_FILE)
text_tensor_md = metadata_info.InputTextTensorMd(self._NAME,
self._DESCRIPTION,
regex_tokenizer_md)
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(
text_tensor_md.create_metadata()))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_create_metadata_by_default_should_succeed(self):
text_tensor_md = metadata_info.InputTextTensorMd()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(
text_tensor_md.create_metadata()))
with open(self._EXPECTED_TENSOR_DEFAULT_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class ClassificationTensorMdTest(parameterized.TestCase): class ClassificationTensorMdTest(parameterized.TestCase):
_NAME = "probability" _NAME = "probability"

View File

@ -113,7 +113,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_initialize_and_populate(self): def test_initialize_and_populate(self):
writer = metadata_writer.MetadataWriter.create( writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer) self.image_classifier_model_buffer)
writer.add_genernal_info( writer.add_general_info(
model_name='my_image_model', model_description='my_description') model_name='my_image_model', model_description='my_description')
tflite_model, metadata_json = writer.populate() tflite_model, metadata_json = writer.populate()
self.assertLen(tflite_model, 1882986) self.assertLen(tflite_model, 1882986)
@ -142,7 +142,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_add_feature_input_output(self): def test_add_feature_input_output(self):
writer = metadata_writer.MetadataWriter.create( writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer) self.image_classifier_model_buffer)
writer.add_genernal_info( writer.add_general_info(
model_name='my_model', model_description='my_description') model_name='my_model', model_description='my_description')
writer.add_feature_input( writer.add_feature_input(
name='input_tesnor', description='a feature input tensor') name='input_tesnor', description='a feature input tensor')
@ -191,7 +191,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_image_classifier(self): def test_image_classifier(self):
writer = metadata_writer.MetadataWriter.create( writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer) self.image_classifier_model_buffer)
writer.add_genernal_info( writer.add_general_info(
model_name='image_classifier', model_name='image_classifier',
model_description='Imagenet classification model') model_description='Imagenet classification model')
writer.add_image_input( writer.add_image_input(
@ -282,7 +282,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_image_classifier_with_locale_and_score_calibration(self): def test_image_classifier_with_locale_and_score_calibration(self):
writer = metadata_writer.MetadataWriter(self.image_classifier_model_buffer) writer = metadata_writer.MetadataWriter(self.image_classifier_model_buffer)
writer.add_genernal_info( writer.add_general_info(
model_name='image_classifier', model_name='image_classifier',
model_description='Classify the input image.') model_description='Classify the input image.')
writer.add_image_input( writer.add_image_input(

View File

@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer.text_classifier."""
from absl.testing import absltest
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils
_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
"movie_review_labels.txt")
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
class TextClassifierTest(absltest.TestCase):
def test_write_metadata(self,):
with open(_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_regex_model(
model_buffer,
regex_tokenizer=metadata_writer.RegexTokenizer(
delim_regex_pattern=_DELIM_REGEX_PATTERN,
vocab_file_path=_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
if __name__ == "__main__":
absltest.main()

View File

@ -154,12 +154,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _TextClassifier) self.assertIsInstance(classifier, _TextClassifier)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _TextClassifierOptions(base_options=base_options) options = _TextClassifierOptions(base_options=base_options)
_TextClassifier.create_from_options(options) _TextClassifier.create_from_options(options)

View File

@ -147,12 +147,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _ImageClassifier) self.assertIsInstance(classifier, _ImageClassifier)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _ImageClassifierOptions(base_options=base_options) options = _ImageClassifierOptions(base_options=base_options)
_ImageClassifier.create_from_options(options) _ImageClassifier.create_from_options(options)

View File

@ -97,12 +97,10 @@ class ImageSegmenterTest(parameterized.TestCase):
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options) _ImageSegmenter.create_from_options(options)

View File

@ -119,12 +119,10 @@ class ObjectDetectorTest(parameterized.TestCase):
self.assertIsInstance(detector, _ObjectDetector) self.assertIsInstance(detector, _ObjectDetector)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _ObjectDetectorOptions(base_options=base_options) options = _ObjectDetectorOptions(base_options=base_options)
_ObjectDetector.create_from_options(options) _ObjectDetector.create_from_options(options)

View File

@ -70,7 +70,7 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",

View File

@ -22,7 +22,7 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions _ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -40,7 +40,7 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
@ -81,13 +81,13 @@ class ImageSegmenterOptions:
[List[image_module.Image], image_module.Image, int], None]] = None [List[image_module.Image], image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterOptionsProto: def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an ImageSegmenterOptions protobuf object.""" """Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value) output_type=self.output_type.value, activation=self.activation.value)
return _ImageSegmenterOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto) segmenter_options=segmenter_options_proto)

View File

@ -31,6 +31,7 @@ mediapipe_files(srcs = [
"mobilenet_v2_1.0_224_quant.tflite", "mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite", "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
"mobilenet_v2_1.0_224_without_metadata.tflite", "mobilenet_v2_1.0_224_without_metadata.tflite",
"movie_review.tflite",
]) ])
exports_files([ exports_files([
@ -54,6 +55,11 @@ exports_files([
"labels.txt", "labels.txt",
"mobilenet_v2_1.0_224.json", "mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json", "mobilenet_v2_1.0_224_quant.json",
"input_text_tensor_meta.json",
"input_text_tensor_default_meta.json",
"movie_review_labels.txt",
"regex_vocab.txt",
"movie_review.json",
]) ])
filegroup( filegroup(
@ -67,6 +73,7 @@ filegroup(
"mobilenet_v2_1.0_224_quant.tflite", "mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite", "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
"mobilenet_v2_1.0_224_without_metadata.tflite", "mobilenet_v2_1.0_224_without_metadata.tflite",
"movie_review.tflite",
], ],
) )
@ -86,9 +93,14 @@ filegroup(
"input_image_tensor_float_meta.json", "input_image_tensor_float_meta.json",
"input_image_tensor_uint8_meta.json", "input_image_tensor_uint8_meta.json",
"input_image_tensor_unsupported_meta.json", "input_image_tensor_unsupported_meta.json",
"input_text_tensor_default_meta.json",
"input_text_tensor_meta.json",
"labels.txt", "labels.txt",
"mobilenet_v2_1.0_224.json", "mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json", "mobilenet_v2_1.0_224_quant.json",
"movie_review.json",
"movie_review_labels.txt",
"regex_vocab.txt",
"score_calibration.txt", "score_calibration.txt",
"score_calibration_file_meta.json", "score_calibration_file_meta.json",
"score_calibration_tensor_meta.json", "score_calibration_tensor_meta.json",

View File

@ -0,0 +1,17 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
]
}
]
}

View File

@ -0,0 +1,34 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input text",
"description": "The input string.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "RegexTokenizerOptions",
"options": {
"delim_regex_pattern": "[^\\w\\']+",
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
],
"stats": {
}
}
]
}
]
}

View File

@ -0,0 +1,63 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input_text",
"description": "Embedding vectors representing the input text to be processed.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "RegexTokenizerOptions",
"options": {
"delim_regex_pattern": "[^\\w\\']+",
"vocab_file": [
{
"name": "regex_vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
],
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.2.1"
}

View File

@ -0,0 +1,2 @@
Negative
Positive

File diff suppressed because it is too large Load Diff

View File

@ -28,6 +28,7 @@ mediapipe_files(srcs = [
"bert_text_classifier.tflite", "bert_text_classifier.tflite",
"mobilebert_embedding_with_metadata.tflite", "mobilebert_embedding_with_metadata.tflite",
"mobilebert_with_metadata.tflite", "mobilebert_with_metadata.tflite",
"regex_one_embedding_with_metadata.tflite",
"test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite",
"universal_sentence_encoder_qa_with_metadata.tflite", "universal_sentence_encoder_qa_with_metadata.tflite",
@ -92,6 +93,11 @@ filegroup(
srcs = ["mobilebert_embedding_with_metadata.tflite"], srcs = ["mobilebert_embedding_with_metadata.tflite"],
) )
filegroup(
name = "regex_embedding_with_metadata",
srcs = ["regex_one_embedding_with_metadata.tflite"],
)
filegroup( filegroup(
name = "universal_sentence_encoder_qa", name = "universal_sentence_encoder_qa",
data = ["universal_sentence_encoder_qa_with_metadata.tflite"], data = ["universal_sentence_encoder_qa_with_metadata.tflite"],

View File

@ -144,8 +144,13 @@ filegroup(
) )
# Gestures related models. Visible to model_maker. # Gestures related models. Visible to model_maker.
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
filegroup( filegroup(
name = "test_gesture_models", name = "test_gesture_models",
srcs = [
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
visibility = [ visibility = [
"//mediapipe/model_maker:__subpackages__", "//mediapipe/model_maker:__subpackages__",
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",

106
mediapipe/tasks/web/BUILD Normal file
View File

@ -0,0 +1,106 @@
# This contains the MediaPipe Tasks NPM package definitions.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
package(default_visibility = ["//mediapipe/tasks:internal"])
# Audio
mediapipe_ts_library(
name = "audio_lib",
srcs = ["audio.ts"],
deps = ["//mediapipe/tasks/web/audio:audio_lib"],
)
rollup_bundle(
name = "audio_bundle",
config_file = "rollup.config.mjs",
entry_point = "audio.ts",
output_dir = False,
deps = [
":audio_lib",
"@npm//@rollup/plugin-commonjs",
"@npm//@rollup/plugin-node-resolve",
],
)
pkg_npm(
name = "audio_pkg",
package_name = "__PACKAGE_NAME__",
srcs = ["package.json"],
substitutions = {
"__PACKAGE_NAME__": "@mediapipe/tasks-audio",
"__DESCRIPTION__": "MediaPipe Audio Tasks",
"__BUNDLE__": "audio_bundle.js",
},
tgz = "audio.tgz",
deps = [":audio_bundle"],
)
# Text
mediapipe_ts_library(
name = "text_lib",
srcs = ["text.ts"],
deps = ["//mediapipe/tasks/web/text:text_lib"],
)
rollup_bundle(
name = "text_bundle",
config_file = "rollup.config.mjs",
entry_point = "text.ts",
output_dir = False,
deps = [
":text_lib",
"@npm//@rollup/plugin-commonjs",
"@npm//@rollup/plugin-node-resolve",
],
)
pkg_npm(
name = "text_pkg",
package_name = "__PACKAGE_NAME__",
srcs = ["package.json"],
substitutions = {
"__PACKAGE_NAME__": "@mediapipe/tasks-text",
"__DESCRIPTION__": "MediaPipe Text Tasks",
"__BUNDLE__": "text_bundle.js",
},
tgz = "text.tgz",
deps = [":text_bundle"],
)
# Vision
mediapipe_ts_library(
name = "vision_lib",
srcs = ["vision.ts"],
deps = ["//mediapipe/tasks/web/vision:vision_lib"],
)
rollup_bundle(
name = "vision_bundle",
config_file = "rollup.config.mjs",
entry_point = "vision.ts",
output_dir = False,
deps = [
":vision_lib",
"@npm//@rollup/plugin-commonjs",
"@npm//@rollup/plugin-node-resolve",
],
)
pkg_npm(
name = "vision_pkg",
package_name = "__PACKAGE_NAME__",
srcs = ["package.json"],
substitutions = {
"__PACKAGE_NAME__": "@mediapipe/tasks-vision",
"__DESCRIPTION__": "MediaPipe Vision Tasks",
"__BUNDLE__": "vision_bundle.js",
},
tgz = "vision.tgz",
deps = [":vision_bundle"],
)

View File

@ -0,0 +1,17 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export * from '../../tasks/web/audio/index';

View File

@ -2,6 +2,8 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library( mediapipe_ts_library(
name = "audio_lib", name = "audio_lib",
srcs = ["index.ts"], srcs = ["index.ts"],

View File

@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
*/ */
async setOptions(options: AudioClassifierOptions): Promise<void> { async setOptions(options: AudioClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }
@ -198,7 +198,7 @@ export class AudioClassifier extends TaskRunner {
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream(
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); 'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);

View File

@ -26,6 +26,8 @@ mediapipe_ts_library(
name = "base_options", name = "base_options",
srcs = ["base_options.ts"], srcs = ["base_options.ts"],
deps = [ deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",

View File

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {BaseOptions} from '../../../../tasks/web/core/base_options';
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
* Converts a BaseOptions API object to its Protobuf representation. * Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided * @throws If neither a model assset path or buffer is provided
*/ */
export async function convertBaseOptionsToProto(baseOptions: BaseOptions): export async function convertBaseOptionsToProto(
Promise<BaseOptionsProto> { updatedOptions: BaseOptions,
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) { currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
throw new Error( const result =
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
} }
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
if (!externalFile.hasFileContent()) {
throw new Error( throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
} }
}
let modelAssetBuffer = baseOptions.modelAssetBuffer;
if (!modelAssetBuffer) { /** Configues the `acceleration` option. */
const response = await fetch(baseOptions.modelAssetPath!.toString()); function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); if ('delegate' in options) {
} const acceleration = new Acceleration();
if (options.delegate === 'cpu') {
const proto = new BaseOptionsProto(); acceleration.setXnnpack(
const externalFile = new ExternalFile(); new InferenceCalculatorOptions.Delegate.Xnnpack());
externalFile.setFileContent(modelAssetBuffer); proto.setAcceleration(acceleration);
proto.setModelAsset(externalFile); } else if (options.delegate === 'gpu') {
return proto; acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
proto.setAcceleration(acceleration);
} else {
proto.clearAcceleration();
}
}
} }

View File

@ -22,10 +22,14 @@ export interface BaseOptions {
* The model path to the model asset file. Only one of `modelAssetPath` or * The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set. * `modelAssetBuffer` can be set.
*/ */
modelAssetPath?: string; modelAssetPath?: string|undefined;
/** /**
* A buffer containing the model aaset. Only one of `modelAssetPath` or * A buffer containing the model aaset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set. * `modelAssetBuffer` can be set.
*/ */
modelAssetBuffer?: Uint8Array; modelAssetBuffer?: Uint8Array|undefined;
/** Overrides the default backend to use for the provided model. */
delegate?: 'cpu'|'gpu'|undefined;
} }

View File

@ -0,0 +1,15 @@
{
"name": "__PACKAGE_NAME__",
"version": "__VERSION__",
"description": "__DESCRIPTION__",
"main": "__BUNDLE__",
"module": "__BUNDLE__",
"author": "mediapipe@google.com",
"license": "Apache-2.0",
"type": "module",
"dependencies": {
"google-protobuf": "^3.21.2"
},
"homepage": "http://mediapipe.dev",
"keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ]
}

Some files were not shown because too many files have changed in this diff Show More