Merge branch 'google:master' into image-embedder-python
This commit is contained in:
commit
36c50ff8f3
|
@ -1324,6 +1324,7 @@ cc_test(
|
|||
name = "image_to_tensor_utils_test",
|
||||
srcs = ["image_to_tensor_utils_test.cc"],
|
||||
deps = [
|
||||
":image_to_tensor_calculator_cc_proto",
|
||||
":image_to_tensor_utils",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
|
|
@ -330,9 +330,8 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_GE(output_shape.dims[0], 1)
|
||||
<< "The batch dimension needs to be greater or equal to 1.";
|
||||
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -172,7 +172,7 @@ constexpr char kValidIntProto[] = R"(
|
|||
output_tensor_height: 200
|
||||
)";
|
||||
|
||||
TEST(ValidateOptionOutputDims, ValidProtos) {
|
||||
TEST(ValidateOptionOutputDims, ImageToTensorCalcOptions) {
|
||||
const auto float_options =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
|
||||
kValidFloatProto);
|
||||
|
@ -202,7 +202,7 @@ TEST(ValidateOptionOutputDims, EmptyProto) {
|
|||
HasSubstr("Valid output tensor width is required")));
|
||||
}
|
||||
|
||||
TEST(GetOutputTensorParams, SetValues) {
|
||||
TEST(GetOutputTensorParams, ImageToTensorCalcOptionsSetValues) {
|
||||
// Test int range with ImageToTensorCalculatorOptions.
|
||||
const auto int_options =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
|
||||
|
|
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_empty.png
vendored
Normal file
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_empty.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 319 B |
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_rect1.png
vendored
Normal file
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_rect1.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_rect2.png
vendored
Normal file
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_rect2.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_rect3.png
vendored
Normal file
BIN
mediapipe/calculators/tensor/testdata/image_to_tensor/crop_rect3.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 17 KiB |
|
@ -106,6 +106,13 @@ class MultiPort : public Single {
|
|||
return Single{&GetWithAutoGrow(&vec_, index)};
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
auto Cast() {
|
||||
using SingleCastT =
|
||||
std::invoke_result_t<decltype(&Single::template Cast<U>), Single*>;
|
||||
return MultiPort<SingleCastT>(&vec_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Base>>& vec_;
|
||||
};
|
||||
|
|
|
@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, MultiPortIsCastToMultiPort) {
|
||||
builder::Graph graph;
|
||||
builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
|
||||
builder::MultiSource<int> int_input = any_input.Cast<int>();
|
||||
builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
|
||||
builder::MultiDestination<int> int_output = any_output.Cast<int>();
|
||||
int_input >> int_output;
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "ANY_INPUT:__stream_0"
|
||||
output_stream: "ANY_OUTPUT:__stream_0"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
|
||||
builder::Graph graph;
|
||||
builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
|
||||
builder::Source<AnyType> any_input = any_multi_input;
|
||||
builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
|
||||
builder::Destination<AnyType> any_output = any_multi_output;
|
||||
any_input >> any_output;
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "ANY_INPUT:__stream_0"
|
||||
output_stream: "ANY_OUTPUT:__stream_0"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
|
||||
builder::Graph graph;
|
||||
builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
|
||||
builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
|
||||
builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
|
||||
builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
|
||||
int_input >> int_output;
|
||||
any_input >> any_output;
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "ANY_OUTPUT:__stream_0"
|
||||
input_stream: "INT_INPUT:__stream_1"
|
||||
output_stream: "ANY_OUTPUT:__stream_0"
|
||||
output_stream: "INT_OUTPUT:__stream_1"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -38,20 +38,18 @@ static pthread_key_t egl_release_thread_key;
|
|||
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
|
||||
|
||||
static void EglThreadExitCallback(void* key_value) {
|
||||
#if defined(__ANDROID__)
|
||||
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
|
||||
EGL_NO_CONTEXT);
|
||||
#else
|
||||
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
|
||||
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
|
||||
// implementations, and should be considered as an undocumented vendor
|
||||
// extension.
|
||||
EGLDisplay current_display = eglGetCurrentDisplay();
|
||||
if (current_display != EGL_NO_DISPLAY) {
|
||||
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid
|
||||
// display parameter for eglMakeCurrent. This behavior is not portable to
|
||||
// all EGL implementations, and should be considered as an undocumented
|
||||
// vendor extension.
|
||||
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
|
||||
//
|
||||
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
|
||||
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
|
||||
EGL_NO_SURFACE, EGL_NO_CONTEXT);
|
||||
#endif
|
||||
// Instead, to release the current context, we pass the current display.
|
||||
// If the current display is already EGL_NO_DISPLAY, no context is current.
|
||||
eglMakeCurrent(current_display, EGL_NO_SURFACE, EGL_NO_SURFACE,
|
||||
EGL_NO_CONTEXT);
|
||||
}
|
||||
eglReleaseThread();
|
||||
}
|
||||
|
||||
|
|
|
@ -20,3 +20,10 @@ package_group(
|
|||
"//mediapipe/model_maker/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "1p_client",
|
||||
packages = [
|
||||
"//research/privacy/learning/fl_eval/pcvr/...",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,7 +19,6 @@ import tempfile
|
|||
from typing import Optional
|
||||
|
||||
|
||||
# TODO: Integrate this class into ImageClassifier and other tasks.
|
||||
@dataclasses.dataclass
|
||||
class BaseHParams:
|
||||
"""Hyperparameters used for training models.
|
||||
|
|
|
@ -45,7 +45,10 @@ py_library(
|
|||
srcs = ["classifier.py"],
|
||||
deps = [
|
||||
":custom_model",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,24 +13,24 @@
|
|||
# limitations under the License.
|
||||
"""Custom classifier."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
||||
|
||||
class Classifier(custom_model.CustomModel):
|
||||
"""An abstract base class that represents a TensorFlow classifier."""
|
||||
|
||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
|
||||
"""Initilizes a classifier with its specifications.
|
||||
def __init__(self, model_spec: Any, label_names: Sequence[str],
|
||||
shuffle: bool):
|
||||
"""Initializes a classifier with its specifications.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
|
@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel):
|
|||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
self._label_names = label_names
|
||||
self._num_classes = len(label_names)
|
||||
self._model: tf.keras.Model = None
|
||||
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
||||
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
||||
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
|
||||
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
||||
self._hparams: hp.BaseHParams = None
|
||||
self._history: tf.keras.callbacks.History = None
|
||||
|
||||
# TODO: Integrate this into all Model Maker tasks.
|
||||
def _train_model(self,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
preprocessor: Optional[Callable[..., bool]] = None):
|
||||
"""Trains the classifier model.
|
||||
|
||||
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
preprocessor: An optional data preprocessor that can be used when
|
||||
generating a tf.data.Dataset.
|
||||
"""
|
||||
tf.compat.v1.logging.info('Training the models...')
|
||||
if len(train_data) < self._hparams.batch_size:
|
||||
raise ValueError(
|
||||
f'The size of the train_data {len(train_data)} can\'t be smaller than'
|
||||
f' batch_size {self._hparams.batch_size}. To solve this problem, set'
|
||||
' the batch_size smaller or increase the size of the train_data.')
|
||||
|
||||
train_dataset = train_data.gen_tf_dataset(
|
||||
batch_size=self._hparams.batch_size,
|
||||
is_training=True,
|
||||
shuffle=self._shuffle,
|
||||
preprocess=preprocessor)
|
||||
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||
batch_size=self._hparams.batch_size,
|
||||
train_data=train_data)
|
||||
train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch)
|
||||
validation_dataset = validation_data.gen_tf_dataset(
|
||||
batch_size=self._hparams.batch_size,
|
||||
is_training=False,
|
||||
preprocess=preprocessor)
|
||||
self._model.compile(
|
||||
optimizer=self._optimizer,
|
||||
loss=self._loss_function,
|
||||
metrics=[self._metric_function])
|
||||
self._history = self._model.fit(
|
||||
x=train_dataset,
|
||||
epochs=self._hparams.epochs,
|
||||
validation_data=validation_dataset,
|
||||
callbacks=self._callbacks)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Evaluates the classifier with the provided evaluation dataset.
|
||||
|
|
|
@ -35,6 +35,7 @@ py_library(
|
|||
name = "model_util",
|
||||
srcs = ["model_util.py"],
|
||||
deps = [
|
||||
":file_util",
|
||||
":quantization",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
|
@ -50,6 +51,18 @@ py_test(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "file_util",
|
||||
srcs = ["file_util.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "file_util_test",
|
||||
srcs = ["file_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
|
||||
deps = [":file_util"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
|
|
36
mediapipe/model_maker/python/core/utils/file_util.py
Normal file
36
mediapipe/model_maker/python/core/utils/file_util.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for files."""
|
||||
|
||||
import os
|
||||
|
||||
# resources dependency
|
||||
|
||||
|
||||
def get_absolute_path(file_path: str) -> str:
|
||||
"""Gets the absolute path of a file.
|
||||
|
||||
Args:
|
||||
file_path: The path to a file relative to the `mediapipe` dir
|
||||
|
||||
Returns:
|
||||
The full path of the file
|
||||
"""
|
||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
||||
# with the `path` which defines the relative path under mediapipe/, it
|
||||
# yields to the absolute path of the model files directory.
|
||||
cwd = os.path.dirname(__file__)
|
||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||
absolute_path = os.path.join(base_dir, file_path)
|
||||
return absolute_path
|
29
mediapipe/model_maker/python/core/utils/file_util_test.py
Normal file
29
mediapipe/model_maker/python/core/utils/file_util_test.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
|
||||
class FileUtilTest(absltest.TestCase):
|
||||
|
||||
def test_get_absolute_path(self):
|
||||
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
|
||||
absolute_path = file_util.get_absolute_path(test_file)
|
||||
self.assertTrue(os.path.exists(absolute_path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for keras models."""
|
||||
"""Utilities for models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -19,21 +19,33 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# resources dependency
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||
ESTIMITED_STEPS_PER_EPOCH = 1000
|
||||
|
||||
|
||||
def get_default_callbacks(
|
||||
export_dir: str) -> Sequence[tf.keras.callbacks.Callback]:
|
||||
"""Gets default callbacks."""
|
||||
summary_dir = os.path.join(export_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
|
||||
checkpoint_path = os.path.join(export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
checkpoint_path, save_weights_only=True)
|
||||
return [summary_callback, checkpoint_callback]
|
||||
|
||||
|
||||
def load_keras_model(model_path: str,
|
||||
compile_on_load: bool = False) -> tf.keras.Model:
|
||||
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
||||
|
@ -49,16 +61,26 @@ def load_keras_model(model_path: str,
|
|||
Returns:
|
||||
A tensorflow Keras model.
|
||||
"""
|
||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
||||
# with the `model_path` which defines the relative path under mediapipe/, it
|
||||
# yields to the aboslution path of the model files directory.
|
||||
cwd = os.path.dirname(__file__)
|
||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||
absolute_path = os.path.join(base_dir, model_path)
|
||||
absolute_path = file_util.get_absolute_path(model_path)
|
||||
return tf.keras.models.load_model(
|
||||
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
||||
|
||||
|
||||
def load_tflite_model_buffer(model_path: str) -> bytearray:
|
||||
"""Loads a TFLite model buffer from file.
|
||||
|
||||
Args:
|
||||
model_path: Relative path to a TFLite file
|
||||
|
||||
Returns:
|
||||
A TFLite model buffer
|
||||
"""
|
||||
absolute_path = file_util.get_absolute_path(model_path)
|
||||
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
|
||||
tflite_model_buffer = f.read()
|
||||
return tflite_model_buffer
|
||||
|
||||
|
||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
train_data: Optional[dataset.Dataset] = None) -> int:
|
||||
|
@ -174,7 +196,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
|||
lambda: self.decay_schedule_fn(step),
|
||||
name=name)
|
||||
|
||||
def get_config(self) -> Dict[Text, Any]:
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'initial_learning_rate': self.initial_learning_rate,
|
||||
'decay_schedule_fn': self.decay_schedule_fn,
|
||||
|
|
|
@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util
|
|||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_load_model(self):
|
||||
def test_load_keras_model(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
|
@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||
self.assertTrue((model_output == loaded_model_output).all())
|
||||
|
||||
def test_load_tflite_model_buffer(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_model = model_util.convert_to_tflite(model)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||
|
||||
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||
test_util.test_tflite(
|
||||
keras_model=model,
|
||||
tflite_model=tflite_model_buffer,
|
||||
size=[1, input_dim])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='input_only_steps_per_epoch',
|
||||
|
|
23
mediapipe/model_maker/python/core/utils/testdata/BUILD
vendored
Normal file
23
mediapipe/model_maker/python/core/utils/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["test.txt"],
|
||||
)
|
0
mediapipe/model_maker/python/core/utils/testdata/test.txt
vendored
Normal file
0
mediapipe/model_maker/python/core/utils/testdata/test.txt
vendored
Normal file
|
@ -28,6 +28,8 @@ py_library(
|
|||
":dataset",
|
||||
":hyperparameters",
|
||||
":image_classifier",
|
||||
":image_classifier_options",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
@ -58,6 +60,24 @@ py_test(
|
|||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_options",
|
||||
srcs = ["model_options.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier_options",
|
||||
srcs = ["image_classifier_options.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -74,6 +94,8 @@ py_library(
|
|||
srcs = ["image_classifier.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":image_classifier_options",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":train_image_classifier_lib",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
|
@ -99,6 +121,7 @@ py_library(
|
|||
|
||||
py_test(
|
||||
name = "image_classifier_test",
|
||||
size = "large",
|
||||
srcs = ["image_classifier_test.py"],
|
||||
shard_count = 2,
|
||||
tags = ["requires-net:external"],
|
||||
|
|
|
@ -16,10 +16,14 @@
|
|||
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
|
||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||
|
||||
ImageClassifier = image_classifier.ImageClassifier
|
||||
HParams = hyperparameters.HParams
|
||||
Dataset = dataset.Dataset
|
||||
ModelOptions = model_options.ImageClassifierModelOptions
|
||||
ModelSpec = model_spec.ModelSpec
|
||||
SupportedModels = model_spec.SupportedModels
|
||||
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
|
||||
|
|
|
@ -14,28 +14,20 @@
|
|||
"""Hyperparameters for training image classification models."""
|
||||
|
||||
import dataclasses
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
||||
# TODO: Expose other hyperparameters, e.g. data augmentation
|
||||
# hyperparameters if requested.
|
||||
@dataclasses.dataclass
|
||||
class HParams:
|
||||
class HParams(hp.BaseHParams):
|
||||
"""The hyperparameters for training image classifiers.
|
||||
|
||||
The hyperparameters include:
|
||||
# Parameters about training data.
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
do_fine_tuning: If true, the base module is trained together with the
|
||||
classification layer on top.
|
||||
shuffle: A boolean controlling if shuffle the dataset. Default to false.
|
||||
|
||||
# Parameters about training configuration
|
||||
train_epochs: Training will do this many iterations over the dataset.
|
||||
batch_size: Each training step samples a batch of this many images.
|
||||
learning_rate: The learning rate to use for gradient descent training.
|
||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||
layer.
|
||||
l1_regularizer: A regularizer that applies a L1 regularization penalty.
|
||||
l2_regularizer: A regularizer that applies a L2 regularization penalty.
|
||||
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
|
||||
|
@ -43,32 +35,21 @@ class HParams:
|
|||
do_data_augmentation: A boolean controlling whether the training dataset is
|
||||
augmented by randomly distorting input images, including random cropping,
|
||||
flipping, etc. See utils.image_preprocessing documentation for details.
|
||||
steps_per_epoch: An optional integer indicate the number of training steps
|
||||
per epoch. If not set, the training pipeline calculates the default steps
|
||||
per epoch as the training dataset size devided by batch size.
|
||||
decay_samples: Number of training samples used to calculate the decay steps
|
||||
and create the training optimizer.
|
||||
warmup_steps: Number of warmup steps for a linear increasing warmup schedule
|
||||
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
|
||||
|
||||
# Parameters about the saved checkpoint
|
||||
model_dir: The location of model checkpoint files and exported model files.
|
||||
on learning rate. Used to set up warmup schedule by model_util.WarmUp.s
|
||||
"""
|
||||
# Parameters about training data
|
||||
do_fine_tuning: bool = False
|
||||
shuffle: bool = False
|
||||
# Parameters from BaseHParams class.
|
||||
learning_rate: float = 0.001
|
||||
batch_size: int = 2
|
||||
epochs: int = 10
|
||||
# Parameters about training configuration
|
||||
train_epochs: int = 5
|
||||
batch_size: int = 32
|
||||
learning_rate: float = 0.005
|
||||
dropout_rate: float = 0.2
|
||||
do_fine_tuning: bool = False
|
||||
l1_regularizer: float = 0.0
|
||||
l2_regularizer: float = 0.0001
|
||||
label_smoothing: float = 0.1
|
||||
do_data_augmentation: bool = True
|
||||
steps_per_epoch: Optional[int] = None
|
||||
# TODO: Use lr_decay in hp.baseHParams to infer decay_samples.
|
||||
decay_samples: int = 10000 * 256
|
||||
warmup_epochs: int = 2
|
||||
|
||||
# Parameters about the saved checkpoint
|
||||
model_dir: str = tempfile.mkdtemp()
|
||||
|
|
|
@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
|
|||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
||||
|
@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
|
|||
"""ImageClassifier for building image classification model."""
|
||||
|
||||
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
|
||||
hparams: hp.HParams):
|
||||
hparams: hp.HParams,
|
||||
model_options: model_opt.ImageClassifierModelOptions):
|
||||
"""Initializes ImageClassifier class.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
label_names: A list of label names for the classes.
|
||||
hparams: The hyperparameters for training image classifier.
|
||||
model_options: Model options for creating image classifier.
|
||||
"""
|
||||
super().__init__(
|
||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||
self._hparams = hparams
|
||||
self._model_options = model_options
|
||||
self._preprocess = image_preprocessing.Preprocessor(
|
||||
input_shape=self._model_spec.input_image_shape,
|
||||
num_classes=self._num_classes,
|
||||
|
@ -57,30 +62,37 @@ class ImageClassifier(classifier.Classifier):
|
|||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_spec: ms.SupportedModels,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
hparams: Optional[hp.HParams] = None,
|
||||
options: image_classifier_options.ImageClassifierOptions,
|
||||
) -> 'ImageClassifier':
|
||||
"""Creates and trains an image classifier.
|
||||
|
||||
Loads data and trains the model based on data for image classification.
|
||||
Loads data and trains the model based on data for image classification. If a
|
||||
checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
|
||||
directory, the training process will load the weight from the checkpoint
|
||||
file for continual training.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
hparams: Hyperparameters for training image classifier.
|
||||
options: configuration to create image classifier.
|
||||
|
||||
Returns:
|
||||
An instance based on ImageClassifier.
|
||||
"""
|
||||
if hparams is None:
|
||||
hparams = hp.HParams()
|
||||
if options.hparams is None:
|
||||
options.hparams = hp.HParams()
|
||||
|
||||
spec = ms.SupportedModels.get(model_spec)
|
||||
if options.model_options is None:
|
||||
options.model_options = model_opt.ImageClassifierModelOptions()
|
||||
|
||||
spec = ms.SupportedModels.get(options.supported_model)
|
||||
image_classifier = cls(
|
||||
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
|
||||
model_spec=spec,
|
||||
label_names=train_data.label_names,
|
||||
hparams=options.hparams,
|
||||
model_options=options.model_options)
|
||||
|
||||
image_classifier._create_model()
|
||||
|
||||
|
@ -90,6 +102,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
|
||||
return image_classifier
|
||||
|
||||
# TODO: Migrate to the shared training library of Model Maker.
|
||||
def _train(self, train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset):
|
||||
"""Trains the model with input train_data.
|
||||
|
@ -142,7 +155,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
|
||||
self._model = tf.keras.Sequential([
|
||||
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
|
||||
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
|
||||
tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
|
||||
tf.keras.layers.Dense(
|
||||
units=self._num_classes,
|
||||
activation='softmax',
|
||||
|
@ -167,10 +180,10 @@ class ImageClassifier(classifier.Classifier):
|
|||
path is {self._hparams.model_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
"""
|
||||
if not tf.io.gfile.exists(self._hparams.model_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.model_dir)
|
||||
tflite_file = os.path.join(self._hparams.model_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=self._model,
|
||||
|
@ -180,7 +193,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
tflite_model,
|
||||
self._model_spec.mean_rgb,
|
||||
self._model_spec.stddev_rgb,
|
||||
labels=metadata_writer.Labels().add(self._label_names))
|
||||
labels=metadata_writer.Labels().add(list(self._label_names)))
|
||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
with open(metadata_file, 'w') as f:
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Options for building image classifier."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageClassifierOptions:
|
||||
"""Configurable options for building image classifier.
|
||||
|
||||
Attributes:
|
||||
supported_model: A model from the SupportedModels enum.
|
||||
model_options: A set of options for configuring the selected model.
|
||||
hparams: A set of hyperparameters used to train the image classifier.
|
||||
"""
|
||||
supported_model: model_spec.SupportedModels
|
||||
model_options: Optional[model_opt.ImageClassifierModelOptions] = None
|
||||
hparams: Optional[hyperparameters.HParams] = None
|
|
@ -13,9 +13,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import filecmp
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from unittest import mock as unittest_mock
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -54,54 +58,74 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
super(ImageClassifierTest, self).setUp()
|
||||
all_data = self._gen_cmy_data()
|
||||
# Splits data, 90% data for training, 10% for testing
|
||||
self.train_data, self.test_data = all_data.split(0.9)
|
||||
self._train_data, self._test_data = all_data.split(0.9)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='mobilenet_v2',
|
||||
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite0',
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite0_change_dropout_rate',
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite2',
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite4',
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
export_dir=tempfile.mkdtemp()))),
|
||||
)
|
||||
def test_create_and_train_model(self,
|
||||
model_spec: image_classifier.SupportedModels,
|
||||
hparams: image_classifier.HParams):
|
||||
def test_create_and_train_model(
|
||||
self, options: image_classifier.ImageClassifierOptions):
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
model_spec=model_spec,
|
||||
train_data=self.train_data,
|
||||
hparams=hparams,
|
||||
validation_data=self.test_data)
|
||||
self._test_accuracy(model)
|
||||
|
||||
def test_efficientnetlite0_model_train_and_export(self):
|
||||
hparams = image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||
train_data=self.train_data,
|
||||
hparams=hparams,
|
||||
validation_data=self.test_data)
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
# Test export_model
|
||||
model.export_model()
|
||||
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
|
||||
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
|
||||
output_metadata_file = os.path.join(options.hparams.export_dir,
|
||||
'metadata.json')
|
||||
output_tflite_file = os.path.join(options.hparams.export_dir,
|
||||
'model.tflite')
|
||||
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
|
||||
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
|
@ -111,10 +135,50 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||
|
||||
def test_continual_training_by_loading_checkpoint(self):
|
||||
mock_stdout = io.StringIO()
|
||||
with mock.patch('sys.stdout', mock_stdout):
|
||||
options = image_classifier.ImageClassifierOptions(
|
||||
supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=5, batch_size=1, shuffle=True))
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||
|
||||
def _test_accuracy(self, model, threshold=0.0):
|
||||
_, accuracy = model.evaluate(self.test_data)
|
||||
_, accuracy = model.evaluate(self._test_data)
|
||||
self.assertGreaterEqual(accuracy, threshold)
|
||||
|
||||
@unittest_mock.patch.object(
|
||||
image_classifier.hyperparameters,
|
||||
'HParams',
|
||||
autospec=True,
|
||||
return_value=image_classifier.HParams(epochs=1))
|
||||
@unittest_mock.patch.object(
|
||||
image_classifier.model_options,
|
||||
'ImageClassifierModelOptions',
|
||||
autospec=True,
|
||||
return_value=image_classifier.ModelOptions())
|
||||
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
||||
self, mock_hparams, mock_model_options):
|
||||
options = image_classifier.ImageClassifierOptions(
|
||||
supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
|
||||
image_classifier.ImageClassifier.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
mock_hparams.assert_called_once()
|
||||
mock_model_options.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load compressed models from tensorflow_hub
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for image classifier models."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageClassifierModelOptions:
|
||||
"""Configurable options for image classifier model.
|
||||
|
||||
Attributes:
|
||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||
layer.
|
||||
"""
|
||||
dropout_rate: float = 0.2
|
|
@ -14,8 +14,6 @@
|
|||
"""Library to train model."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
@ -49,18 +47,6 @@ def _create_optimizer(init_lr: float, decay_steps: int,
|
|||
return optimizer
|
||||
|
||||
|
||||
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
|
||||
"""Gets default callbacks."""
|
||||
summary_dir = os.path.join(model_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
# Save checkpoint every 20 epochs.
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
checkpoint_path, save_weights_only=True, period=20)
|
||||
return [summary_callback, checkpoint_callback]
|
||||
|
||||
|
||||
def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||
train_ds: tf.data.Dataset,
|
||||
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
|
||||
|
@ -81,7 +67,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
|||
learning_rate = hparams.learning_rate * hparams.batch_size / 256
|
||||
|
||||
# Get decay steps.
|
||||
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
|
||||
# NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
|
||||
total_training_steps = hparams.steps_per_epoch * hparams.epochs
|
||||
default_decay_steps = hparams.decay_samples // hparams.batch_size
|
||||
decay_steps = max(total_training_steps, default_decay_steps)
|
||||
|
||||
|
@ -92,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
|||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
label_smoothing=hparams.label_smoothing)
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
callbacks = _get_default_callbacks(hparams.model_dir)
|
||||
|
||||
summary_dir = os.path.join(hparams.export_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
# Save checkpoint every 5 epochs.
|
||||
checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True,
|
||||
period=5)
|
||||
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if latest_checkpoint:
|
||||
print(f'Resuming from {latest_checkpoint}')
|
||||
model.load_weights(latest_checkpoint)
|
||||
|
||||
# Train the model.
|
||||
return model.fit(
|
||||
x=train_ds,
|
||||
epochs=hparams.train_epochs,
|
||||
epochs=hparams.epochs,
|
||||
validation_data=validation_ds,
|
||||
callbacks=callbacks)
|
||||
callbacks=[summary_callback, checkpoint_callback])
|
||||
|
|
|
@ -87,7 +87,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "builtin_task_graphs",
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
|
@ -95,8 +94,10 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
] + select({
|
||||
# TODO: Build text_classifier_graph on Windows.
|
||||
# TODO: Build audio_classifier_graph on Windows.
|
||||
"//mediapipe:windows": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
||||
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
||||
],
|
||||
}),
|
||||
|
|
|
@ -30,6 +30,15 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hand_landmarks_detection_result",
|
||||
hdrs = ["hand_landmarks_detection_result.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "category",
|
||||
srcs = ["category.cc"],
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace containers {
|
||||
|
||||
// The hand landmarks detection result from HandLandmarker, where each vector
|
||||
// element represents a single hand detected in the image.
|
||||
struct HandLandmarksDetectionResult {
|
||||
// Classification of handedness.
|
||||
std::vector<mediapipe::ClassificationList> handedness;
|
||||
// Detected hand landmarks in normalized image coordinates.
|
||||
std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
|
||||
// Detected hand landmarks in world coordinates.
|
||||
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
|
||||
};
|
||||
|
||||
} // namespace containers
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
|
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.proto;
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.proto";
|
||||
option java_outer_classname = "SegmenterOptionsProto";
|
||||
|
||||
// Shared options used by image segmentation tasks.
|
||||
message SegmenterOptions {
|
||||
// Optional output mask type.
|
||||
|
|
87
mediapipe/tasks/cc/text/text_embedder/BUILD
Normal file
87
mediapipe/tasks/cc/text/text_embedder/BUILD
Normal file
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "text_embedder",
|
||||
srcs = ["text_embedder.cc"],
|
||||
hdrs = ["text_embedder.h"],
|
||||
deps = [
|
||||
":text_embedder_graph",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_embedder_graph",
|
||||
srcs = ["text_embedder_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "text_embedder_test",
|
||||
srcs = ["text_embedder_test.cc"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||
],
|
||||
deps = [
|
||||
":text_embedder",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
)
|
30
mediapipe/tasks/cc/text/text_embedder/proto/BUILD
Normal file
30
mediapipe/tasks/cc/text/text_embedder/proto/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "text_embedder_graph_options_proto",
|
||||
srcs = ["text_embedder_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.text.text_embedder.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message TextEmbedderGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TextEmbedderGraphOptions ext = 477589892;
|
||||
}
|
||||
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
||||
optional core.proto.BaseOptions base_options = 1;
|
||||
|
||||
// Options for configuring the embedder behavior, such as normalization or
|
||||
// quantization.
|
||||
optional components.processors.proto.EmbedderOptions embedder_options = 2;
|
||||
}
|
104
mediapipe/tasks/cc/text/text_embedder/text_embedder.cc
Normal file
104
mediapipe/tasks/cc/text/text_embedder/text_embedder.cc
Normal file
|
@ -0,0 +1,104 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
||||
constexpr char kTextTag[] = "TEXT";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kTextInStreamName[] = "text_in";
|
||||
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
|
||||
constexpr char kGraphTypeName[] =
|
||||
"mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
|
||||
|
||||
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
|
||||
// Creates a MediaPipe graph config that contains a single node of type
|
||||
// "mediapipe.tasks.text.text_embedder.TextEmbedderGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto) {
|
||||
api2::builder::Graph graph;
|
||||
auto& task_graph = graph.AddNode(kGraphTypeName);
|
||||
task_graph.GetOptions<proto::TextEmbedderGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
graph.In(kTextTag).SetName(kTextInStreamName) >> task_graph.In(kTextTag);
|
||||
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
|
||||
graph.Out(kEmbeddingsTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Converts the user-facing TextEmbedderOptions struct to the internal
|
||||
// TextEmbedderGraphOptions proto.
|
||||
std::unique_ptr<proto::TextEmbedderGraphOptions>
|
||||
ConvertTextEmbedderOptionsToProto(TextEmbedderOptions* options) {
|
||||
auto options_proto = std::make_unique<proto::TextEmbedderGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
auto embedder_options_proto =
|
||||
std::make_unique<components::processors::proto::EmbedderOptions>(
|
||||
components::processors::ConvertEmbedderOptionsToProto(
|
||||
&(options->embedder_options)));
|
||||
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<TextEmbedder>> TextEmbedder::Create(
|
||||
std::unique_ptr<TextEmbedderOptions> options) {
|
||||
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto =
|
||||
ConvertTextEmbedderOptionsToProto(options.get());
|
||||
return core::TaskApiFactory::Create<TextEmbedder,
|
||||
proto::TextEmbedderGraphOptions>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver));
|
||||
}
|
||||
|
||||
absl::StatusOr<TextEmbedderResult> TextEmbedder::Embed(absl::string_view text) {
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
runner_->Process(
|
||||
{{kTextInStreamName, MakePacket<std::string>(std::string(text))}}));
|
||||
return ConvertToEmbeddingResult(
|
||||
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
|
||||
}
|
||||
|
||||
absl::StatusOr<double> TextEmbedder::CosineSimilarity(
|
||||
const components::containers::Embedding& u,
|
||||
const components::containers::Embedding& v) {
|
||||
return components::utils::CosineSimilarity(u, v);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
96
mediapipe/tasks/cc/text/text_embedder/text_embedder.h
Normal file
96
mediapipe/tasks/cc/text/text_embedder/text_embedder.h
Normal file
|
@ -0,0 +1,96 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
|
||||
// Alias the shared EmbeddingResult struct as result typo.
|
||||
using TextEmbedderResult =
|
||||
::mediapipe::tasks::components::containers::EmbeddingResult;
|
||||
|
||||
// Options for configuring a MediaPipe text embedder task.
|
||||
struct TextEmbedderOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// Options for configuring the embedder behavior, such as L2-normalization or
|
||||
// scalar-quantization.
|
||||
components::processors::EmbedderOptions embedder_options;
|
||||
};
|
||||
|
||||
// Performs embedding extraction on text.
|
||||
//
|
||||
// This API expects a TFLite model with TFLite Model Metadata that contains the
|
||||
// mandatory (described below) input tensors and output tensors. Metadata should
|
||||
// contain the input process unit for the model's Tokenizer as well as input /
|
||||
// output tensor metadata.
|
||||
//
|
||||
// TODO: Support Universal Sentence Encoder.
|
||||
// Input tensors:
|
||||
// (kTfLiteInt32)
|
||||
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
|
||||
// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
|
||||
// segment ids respectively
|
||||
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
|
||||
// input ids
|
||||
//
|
||||
// At least one output tensor with:
|
||||
// (kTfLiteFloat32)
|
||||
// - `N` components corresponding to the `N` dimensions of the returned
|
||||
// feature vector for this output layer.
|
||||
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
|
||||
class TextEmbedder : core::BaseTaskApi {
|
||||
public:
|
||||
using BaseTaskApi::BaseTaskApi;
|
||||
|
||||
// Creates a TextEmbedder from the provided `options`. A non-default
|
||||
// OpResolver can be specified in the BaseOptions in order to support custom
|
||||
// Ops or specify a subset of built-in Ops.
|
||||
static absl::StatusOr<std::unique_ptr<TextEmbedder>> Create(
|
||||
std::unique_ptr<TextEmbedderOptions> options);
|
||||
|
||||
// Performs embedding extraction on the input `text`.
|
||||
absl::StatusOr<TextEmbedderResult> Embed(absl::string_view text);
|
||||
|
||||
// Shuts down the TextEmbedder when all the work is done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
|
||||
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||
// types (quantized vs. float), have different sizes, or have a an L2-norm of
|
||||
// 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
static absl::StatusOr<double> CosineSimilarity(
|
||||
const components::containers::Embedding& u,
|
||||
const components::containers::Embedding& v);
|
||||
};
|
||||
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
145
mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
Normal file
145
mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
Normal file
|
@ -0,0 +1,145 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kTextTag[] = "TEXT";
|
||||
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
|
||||
// extraction.
|
||||
// - Accepts input text and outputs embeddings on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
// Input text to perform embedding extraction on.
|
||||
//
|
||||
// Outputs:
|
||||
// EMBEDDINGS - EmbeddingResult
|
||||
// The embedding result.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.text.TextEmbedderGraph"
|
||||
// input_stream: "TEXT:text_in"
|
||||
// output_stream: "EMBEDDINGS:embedding_result_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.text.text_embedder.proto.TextEmbedderGraphOptions.ext] {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "/path/to/model.tflite"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TextEmbedderGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
CHECK(sc != nullptr);
|
||||
ASSIGN_OR_RETURN(const ModelResources* model_resources,
|
||||
CreateModelResources<proto::TextEmbedderGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
Source<EmbeddingResult> embedding_result_out,
|
||||
BuildTextEmbedderTask(sc->Options<proto::TextEmbedderGraphOptions>(),
|
||||
*model_resources,
|
||||
graph[Input<std::string>(kTextTag)], graph));
|
||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
// Adds a mediapipe TextEmbedder task graph into the provided
|
||||
// builder::Graph instance. The TextEmbedder task takes an input
|
||||
// text (std::string) and returns an embedding result.
|
||||
//
|
||||
// task_options: the mediapipe tasks TextEmbedderGraphOptions proto.
|
||||
// model_resources: the ModelResources object initialized from a
|
||||
// TextEmbedder model file with model metadata.
|
||||
// text_in: (std::string) stream to run embedding extraction on.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<EmbeddingResult>> BuildTextEmbedderTask(
|
||||
const proto::TextEmbedderGraphOptions& task_options,
|
||||
const ModelResources& model_resources, Source<std::string> text_in,
|
||||
Graph& graph) {
|
||||
// Adds preprocessing calculators and connects them to the text input
|
||||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
|
||||
model_resources,
|
||||
preprocessing.GetOptions<
|
||||
tasks::components::proto::TextPreprocessingGraphOptions>()));
|
||||
text_in >> preprocessing.In(kTextTag);
|
||||
|
||||
// Adds both InferenceCalculator and ModelResourcesCalculator.
|
||||
auto& inference = AddInference(
|
||||
model_resources, task_options.base_options().acceleration(), graph);
|
||||
// The metadata extractor side-output comes from the
|
||||
// ModelResourcesCalculator.
|
||||
inference.SideOut(kMetadataExtractorTag) >>
|
||||
preprocessing.SideIn(kMetadataExtractorTag);
|
||||
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
|
||||
|
||||
// Adds postprocessing calculators and connects its input stream to the
|
||||
// inference results.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the embedding result.
|
||||
return postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::text::text_embedder::TextEmbedderGraph);
|
||||
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
143
mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
Normal file
143
mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
Normal file
|
@ -0,0 +1,143 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
|
||||
// Note that these models use dynamic-sized tensors.
|
||||
// Embedding model with BERT preprocessing.
|
||||
constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
|
||||
// Embedding model with regex preprocessing.
|
||||
constexpr char kRegexOneEmbeddingModel[] =
|
||||
"regex_one_embedding_with_metadata.tflite";
|
||||
|
||||
// Tolerance for embedding vector coordinate values.
|
||||
constexpr float kEpsilon = 1e-4;
|
||||
// Tolerancy for cosine similarity evaluation.
|
||||
constexpr double kSimilarityTolerancy = 1e-6;
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
class EmbedderTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
||||
auto text_embedder =
|
||||
TextEmbedder::Create(std::make_unique<TextEmbedderOptions>());
|
||||
ASSERT_EQ(text_embedder.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
ASSERT_THAT(
|
||||
text_embedder.status().message(),
|
||||
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
||||
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
||||
ASSERT_THAT(text_embedder.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithMobileBert) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileBert);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result0,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
ASSERT_EQ(result0.embeddings.size(), 1);
|
||||
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
|
||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||
ASSERT_EQ(result1.embeddings.size(), 1);
|
||||
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
|
||||
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST(EmbedTest, SucceedsWithRegexOneEmbeddingModel) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kRegexOneEmbeddingModel);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result0,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
EXPECT_EQ(result0.embeddings.size(), 1);
|
||||
EXPECT_EQ(result0.embeddings[0].float_embedding.size(), 16);
|
||||
|
||||
EXPECT_NEAR(result0.embeddings[0].float_embedding[0], 0.0309356f, kEpsilon);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||
EXPECT_EQ(result1.embeddings.size(), 1);
|
||||
EXPECT_EQ(result1.embeddings[0].float_embedding.size(), 16);
|
||||
|
||||
EXPECT_NEAR(result1.embeddings[0].float_embedding[0], 0.0312863f, kEpsilon);
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
EXPECT_NEAR(similarity, 0.999937, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileBert);
|
||||
options->embedder_options.quantize = true;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
ASSERT_EQ(result.embeddings.size(), 1);
|
||||
ASSERT_EQ(result.embeddings[0].quantized_embedding.size(), 512);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
|
@ -110,4 +110,38 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hand_landmarker",
|
||||
srcs = ["hand_landmarker.cc"],
|
||||
hdrs = ["hand_landmarker.h"],
|
||||
deps = [
|
||||
":hand_landmarker_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Enable this test
|
||||
|
|
269
mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
Normal file
269
mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
Normal file
|
@ -0,0 +1,269 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
|
||||
namespace {
|
||||
|
||||
using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
hand_landmarker::proto::HandLandmarkerGraphOptions;
|
||||
|
||||
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
|
||||
|
||||
constexpr char kHandLandmarkerGraphTypeName[] =
|
||||
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
|
||||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kImageInStreamName[] = "image_in";
|
||||
constexpr char kImageOutStreamName[] = "image_out";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
constexpr char kHandednessTag[] = "HANDEDNESS";
|
||||
constexpr char kHandednessStreamName[] = "handedness";
|
||||
constexpr char kHandLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kHandLandmarksStreamName[] = "landmarks";
|
||||
constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
|
||||
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
// "mediapipe.tasks.vision.hand_ladnamrker.HandLandmarkerGraph". If the task is
|
||||
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
|
||||
// limit the number of frames in flight.
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<HandLandmarkerGraphOptionsProto> options,
|
||||
bool enable_flow_limiting) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kHandLandmarkerGraphTypeName);
|
||||
subgraph.GetOptions<HandLandmarkerGraphOptionsProto>().Swap(options.get());
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>
|
||||
graph.Out(kHandednessTag);
|
||||
subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >>
|
||||
graph.Out(kHandLandmarksTag);
|
||||
subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >>
|
||||
graph.Out(kHandWorldLandmarksTag);
|
||||
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
||||
if (enable_flow_limiting) {
|
||||
return tasks::core::AddFlowLimiterCalculator(
|
||||
graph, subgraph, {kImageTag, kNormRectTag}, kHandLandmarksTag);
|
||||
}
|
||||
graph.In(kImageTag) >> subgraph.In(kImageTag);
|
||||
graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Converts the user-facing HandLandmarkerOptions struct to the internal
|
||||
// HandLandmarkerGraphOptions proto.
|
||||
std::unique_ptr<HandLandmarkerGraphOptionsProto>
|
||||
ConvertHandLandmarkerGraphOptionsProto(HandLandmarkerOptions* options) {
|
||||
auto options_proto = std::make_unique<HandLandmarkerGraphOptionsProto>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode != core::RunningMode::IMAGE);
|
||||
|
||||
// Configure hand detector options.
|
||||
auto* hand_detector_graph_options =
|
||||
options_proto->mutable_hand_detector_graph_options();
|
||||
hand_detector_graph_options->set_num_hands(options->num_hands);
|
||||
hand_detector_graph_options->set_min_detection_confidence(
|
||||
options->min_hand_detection_confidence);
|
||||
|
||||
// Configure hand landmark detector options.
|
||||
options_proto->set_min_tracking_confidence(options->min_tracking_confidence);
|
||||
auto* hand_landmarks_detector_graph_options =
|
||||
options_proto->mutable_hand_landmarks_detector_graph_options();
|
||||
hand_landmarks_detector_graph_options->set_min_detection_confidence(
|
||||
options->min_hand_presence_confidence);
|
||||
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
|
||||
std::unique_ptr<HandLandmarkerOptions> options) {
|
||||
auto options_proto = ConvertHandLandmarkerGraphOptionsProto(options.get());
|
||||
tasks::core::PacketsCallback packets_callback = nullptr;
|
||||
if (options->result_callback) {
|
||||
auto result_callback = options->result_callback;
|
||||
packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
|
||||
status_or_packets) {
|
||||
if (!status_or_packets.ok()) {
|
||||
Image image;
|
||||
result_callback(status_or_packets.status(), image,
|
||||
Timestamp::Unset().Value());
|
||||
return;
|
||||
}
|
||||
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
||||
return;
|
||||
}
|
||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||
if (status_or_packets.value()[kHandLandmarksStreamName].IsEmpty()) {
|
||||
Packet empty_packet =
|
||||
status_or_packets.value()[kHandLandmarksStreamName];
|
||||
result_callback(
|
||||
{HandLandmarksDetectionResult()}, image_packet.Get<Image>(),
|
||||
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
|
||||
return;
|
||||
}
|
||||
Packet handedness_packet =
|
||||
status_or_packets.value()[kHandednessStreamName];
|
||||
Packet hand_landmarks_packet =
|
||||
status_or_packets.value()[kHandLandmarksStreamName];
|
||||
Packet hand_world_landmarks_packet =
|
||||
status_or_packets.value()[kHandWorldLandmarksStreamName];
|
||||
result_callback(
|
||||
{{handedness_packet.Get<std::vector<ClassificationList>>(),
|
||||
hand_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
|
||||
hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()}},
|
||||
image_packet.Get<Image>(),
|
||||
hand_landmarks_packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond);
|
||||
};
|
||||
}
|
||||
return core::VisionTaskApiFactory::Create<HandLandmarker,
|
||||
HandLandmarkerGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
}
|
||||
|
||||
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"GPU input images are currently not supported.",
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
NormalizedRect norm_rect,
|
||||
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
ProcessImageData(
|
||||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||
if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
|
||||
return {HandLandmarksDetectionResult()};
|
||||
}
|
||||
return {{/* handedness= */
|
||||
{output_packets[kHandednessStreamName]
|
||||
.Get<std::vector<mediapipe::ClassificationList>>()},
|
||||
/* hand_landmarks= */
|
||||
{output_packets[kHandLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
|
||||
/* hand_world_landmarks */
|
||||
{output_packets[kHandWorldLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::LandmarkList>>()}}};
|
||||
}
|
||||
|
||||
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat("GPU input images are currently not supported."),
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
NormalizedRect norm_rect,
|
||||
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
ProcessVideoData(
|
||||
{{kImageInStreamName,
|
||||
MakePacket<Image>(std::move(image))
|
||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
|
||||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||
if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
|
||||
return {HandLandmarksDetectionResult()};
|
||||
}
|
||||
return {
|
||||
{/* handedness= */
|
||||
{output_packets[kHandednessStreamName]
|
||||
.Get<std::vector<mediapipe::ClassificationList>>()},
|
||||
/* hand_landmarks= */
|
||||
{output_packets[kHandLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
|
||||
/* hand_world_landmarks */
|
||||
{output_packets[kHandWorldLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::LandmarkList>>()}},
|
||||
};
|
||||
}
|
||||
|
||||
absl::Status HandLandmarker::DetectAsync(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat("GPU input images are currently not supported."),
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
NormalizedRect norm_rect,
|
||||
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||
return SendLiveStreamData(
|
||||
{{kImageInStreamName,
|
||||
MakePacket<Image>(std::move(image))
|
||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
|
||||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||
}
|
||||
|
||||
} // namespace hand_landmarker
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
192
mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h
Normal file
192
mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h
Normal file
|
@ -0,0 +1,192 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
|
||||
struct HandLandmarkerOptions {
|
||||
// Base options for configuring MediaPipe Tasks library, such as specifying
|
||||
// the TfLite model bundle file with metadata, accelerator options, op
|
||||
// resolver, etc.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// The running mode of the task. Default to the image mode.
|
||||
// HandLandmarker has three running modes:
|
||||
// 1) The image mode for detecting hand landmarks on single image inputs.
|
||||
// 2) The video mode for detecting hand landmarks on the decoded frames of a
|
||||
// video.
|
||||
// 3) The live stream mode for detecting hand landmarks on the live stream of
|
||||
// input data, such as from camera. In this mode, the "result_callback"
|
||||
// below must be specified to receive the detection results asynchronously.
|
||||
core::RunningMode running_mode = core::RunningMode::IMAGE;
|
||||
|
||||
// The maximum number of hands can be detected by the HandLandmarker.
|
||||
int num_hands = 1;
|
||||
|
||||
// The minimum confidence score for the hand detection to be considered
|
||||
// successful.
|
||||
float min_hand_detection_confidence = 0.5;
|
||||
|
||||
// The minimum confidence score of hand presence score in the hand landmark
|
||||
// detection.
|
||||
float min_hand_presence_confidence = 0.5;
|
||||
|
||||
// The minimum confidence score for the hand tracking to be considered
|
||||
// successful.
|
||||
float min_tracking_confidence = 0.5;
|
||||
|
||||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM.
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
|
||||
const Image&, int64)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
// Performs hand landmarks detection on the given image.
|
||||
//
|
||||
// TODO add the link to DevSite.
|
||||
// This API expects a pre-trained hand landmarker model asset bundle.
|
||||
//
|
||||
// Inputs:
|
||||
// Image
|
||||
// - The image that hand landmarks detection runs on.
|
||||
// std::optional<NormalizedRect>
|
||||
// - If provided, can be used to specify the rotation to apply to the image
|
||||
// before performing hand landmarks detection, by setting its 'rotation'
|
||||
// field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation).
|
||||
// Note that specifying a region-of-interest using the 'x_center',
|
||||
// 'y_center', 'width' and 'height' fields is NOT supported and will
|
||||
// result in an invalid argument error being returned.
|
||||
// Outputs:
|
||||
// HandLandmarksDetectionResult
|
||||
// - The hand landmarks detection results.
|
||||
class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
|
||||
public:
|
||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||
|
||||
// Creates a HandLandmarker from a HandLandmarkerOptions to process image data
|
||||
// or streaming data. Hand landmarker can be created with one of the following
|
||||
// three running modes:
|
||||
// 1) Image mode for detecting hand landmarks on single image inputs. Users
|
||||
// provide mediapipe::Image to the `Detect` method, and will receive the
|
||||
// deteced hand landmarks results as the return value.
|
||||
// 2) Video mode for detecting hand landmarks on the decoded frames of a
|
||||
// video. Users call `DetectForVideo` method, and will receive the detected
|
||||
// hand landmarks results as the return value.
|
||||
// 3) Live stream mode for detecting hand landmarks on the live stream of the
|
||||
// input data, such as from camera. Users call `DetectAsync` to push the
|
||||
// image data into the HandLandmarker, the detected results along with the
|
||||
// input timestamp and the image that hand landmarker runs on will be
|
||||
// available in the result callback when the hand landmarker finishes the
|
||||
// work.
|
||||
static absl::StatusOr<std::unique_ptr<HandLandmarker>> Create(
|
||||
std::unique_ptr<HandLandmarkerOptions> options);
|
||||
|
||||
// Performs hand landmarks detection on the given image.
|
||||
// Only use this method when the HandLandmarker is created with the image
|
||||
// running mode.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify
|
||||
// the rotation to apply to the image before performing detection, by setting
|
||||
// its 'rotation_degrees' field. Note that specifying a region-of-interest
|
||||
// using the 'region_of_interest' field is NOT supported and will result in an
|
||||
// invalid argument error being returned.
|
||||
//
|
||||
// The image can be of any size with format RGB or RGBA.
|
||||
// TODO: Describes how the input image will be preprocessed
|
||||
// after the yuv support is implemented.
|
||||
absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect(
|
||||
Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
||||
// Performs hand landmarks detection on the provided video frame.
|
||||
// Only use this method when the HandLandmarker is created with the video
|
||||
// running mode.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify
|
||||
// the rotation to apply to the image before performing detection, by setting
|
||||
// its 'rotation_degrees' field. Note that specifying a region-of-interest
|
||||
// using the 'region_of_interest' field is NOT supported and will result in an
|
||||
// invalid argument error being returned.
|
||||
//
|
||||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
absl::StatusOr<components::containers::HandLandmarksDetectionResult>
|
||||
DetectForVideo(Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions>
|
||||
image_processing_options = std::nullopt);
|
||||
|
||||
// Sends live image data to perform hand landmarks detection, and the results
|
||||
// will be available via the "result_callback" provided in the
|
||||
// HandLandmarkerOptions. Only use this method when the HandLandmarker
|
||||
// is created with the live stream running mode.
|
||||
//
|
||||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
// sent to the hand landmarker. The input timestamps must be monotonically
|
||||
// increasing.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify
|
||||
// the rotation to apply to the image before performing detection, by setting
|
||||
// its 'rotation_degrees' field. Note that specifying a region-of-interest
|
||||
// using the 'region_of_interest' field is NOT supported and will result in an
|
||||
// invalid argument error being returned.
|
||||
//
|
||||
// The "result_callback" provides
|
||||
// - A vector of HandLandmarksDetectionResult, each is the detected results
|
||||
// for a input frame.
|
||||
// - The const reference to the corresponding input image that the hand
|
||||
// landmarker runs on. Note that the const reference to the image will no
|
||||
// longer be valid when the callback returns. To access the image data
|
||||
// outside of the callback, callers need to make a copy of the image.
|
||||
// - The input timestamp in milliseconds.
|
||||
absl::Status DetectAsync(Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions>
|
||||
image_processing_options = std::nullopt);
|
||||
|
||||
// Shuts down the HandLandmarker when all works are done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace hand_landmarker
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
|
|
@ -0,0 +1,511 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::file::Defaults;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::EqualsProto;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
using ::testing::Pointwise;
|
||||
using ::testing::TestParamInfo;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
using ::testing::proto::Approximately;
|
||||
using ::testing::proto::Partially;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task";
|
||||
constexpr char kThumbUpLandmarksFilename[] = "thumb_up_landmarks.pbtxt";
|
||||
constexpr char kPointingUpLandmarksFilename[] = "pointing_up_landmarks.pbtxt";
|
||||
constexpr char kPointingUpRotatedLandmarksFilename[] =
|
||||
"pointing_up_rotated_landmarks.pbtxt";
|
||||
constexpr char kThumbUpImage[] = "thumb_up.jpg";
|
||||
constexpr char kPointingUpImage[] = "pointing_up.jpg";
|
||||
constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg";
|
||||
constexpr char kNoHandsImage[] = "cats_and_dogs.jpg";
|
||||
|
||||
constexpr float kLandmarksFractionDiff = 0.03; // percentage
|
||||
constexpr float kLandmarksAbsMargin = 0.03;
|
||||
constexpr float kHandednessMargin = 0.05;
|
||||
|
||||
LandmarksDetectionResult GetLandmarksDetectionResult(
|
||||
absl::string_view landmarks_file_name) {
|
||||
LandmarksDetectionResult result;
|
||||
MP_EXPECT_OK(GetTextProto(
|
||||
file::JoinPath("./", kTestDataDirectory, landmarks_file_name), &result,
|
||||
Defaults()));
|
||||
// Remove z position of landmarks, because they are not used in correctness
|
||||
// testing. For video or live stream mode, the z positions varies a lot during
|
||||
// tracking from frame to frame.
|
||||
for (int i = 0; i < result.landmarks().landmark().size(); i++) {
|
||||
auto& landmark = *result.mutable_landmarks()->mutable_landmark(i);
|
||||
landmark.clear_z();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
|
||||
const std::vector<absl::string_view>& landmarks_file_names) {
|
||||
HandLandmarksDetectionResult expected_results;
|
||||
for (const auto& file_name : landmarks_file_names) {
|
||||
const auto landmarks_detection_result =
|
||||
GetLandmarksDetectionResult(file_name);
|
||||
expected_results.hand_landmarks.push_back(
|
||||
landmarks_detection_result.landmarks());
|
||||
expected_results.handedness.push_back(
|
||||
landmarks_detection_result.classifications());
|
||||
}
|
||||
return expected_results;
|
||||
}
|
||||
|
||||
void ExpectHandLandmarksDetectionResultsCorrect(
|
||||
const HandLandmarksDetectionResult& actual_results,
|
||||
const HandLandmarksDetectionResult& expected_results) {
|
||||
const auto& actual_landmarks = actual_results.hand_landmarks;
|
||||
const auto& actual_handedness = actual_results.handedness;
|
||||
|
||||
const auto& expected_landmarks = expected_results.hand_landmarks;
|
||||
const auto& expected_handedness = expected_results.handedness;
|
||||
|
||||
ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size());
|
||||
ASSERT_EQ(actual_handedness.size(), expected_handedness.size());
|
||||
|
||||
EXPECT_THAT(
|
||||
actual_handedness,
|
||||
Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin),
|
||||
expected_handedness));
|
||||
EXPECT_THAT(actual_landmarks,
|
||||
Pointwise(Approximately(Partially(EqualsProto()),
|
||||
/*margin=*/kLandmarksAbsMargin,
|
||||
/*fraction=*/kLandmarksFractionDiff),
|
||||
expected_landmarks));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct TestParams {
|
||||
// The name of this test, for convenience when displaying test results.
|
||||
std::string test_name;
|
||||
// The filename of test image.
|
||||
std::string test_image_name;
|
||||
// The filename of test model.
|
||||
std::string test_model_file;
|
||||
// The rotation to apply to the test image before processing, in degrees
|
||||
// clockwise.
|
||||
int rotation;
|
||||
// Expected results from the hand landmarker model output.
|
||||
HandLandmarksDetectionResult expected_results;
|
||||
};
|
||||
|
||||
class ImageModeTest : public testing::TestWithParam<TestParams> {};
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
|
||||
options->running_mode = core::RunningMode::IMAGE;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
auto results = hand_landmarker->DetectForVideo(image, 0);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("not initialized with the video mode"));
|
||||
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
|
||||
|
||||
results = hand_landmarker->DetectAsync(image, 0);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("not initialized with the live stream mode"));
|
||||
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
|
||||
MP_ASSERT_OK(hand_landmarker->Close());
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
|
||||
options->running_mode = core::RunningMode::IMAGE;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
auto results = hand_landmarker->Detect(image, image_processing_options);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("This task doesn't support region-of-interest"));
|
||||
EXPECT_THAT(
|
||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||
}
|
||||
|
||||
TEST_P(ImageModeTest, Succeeds) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||
GetParam().test_image_name)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
|
||||
options->running_mode = core::RunningMode::IMAGE;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
HandLandmarksDetectionResult hand_landmarker_results;
|
||||
if (GetParam().rotation != 0) {
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = GetParam().rotation;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
hand_landmarker_results,
|
||||
hand_landmarker->Detect(image, image_processing_options));
|
||||
} else {
|
||||
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
|
||||
hand_landmarker->Detect(image));
|
||||
}
|
||||
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
|
||||
GetParam().expected_results);
|
||||
MP_ASSERT_OK(hand_landmarker->Close());
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
HandGestureTest, ImageModeTest,
|
||||
Values(TestParams{
|
||||
/* test_name= */ "LandmarksThumbUp",
|
||||
/* test_image_name= */ kThumbUpImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kThumbUpLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "LandmarksPointingUp",
|
||||
/* test_image_name= */ kPointingUpImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kPointingUpLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "LandmarksPointingUpRotated",
|
||||
/* test_image_name= */ kPointingUpRotatedImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ -90,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kPointingUpRotatedLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "NoHands",
|
||||
/* test_image_name= */ kNoHandsImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
{{}, {}, {}},
|
||||
}),
|
||||
[](const TestParamInfo<ImageModeTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
class VideoModeTest : public testing::TestWithParam<TestParams> {};
|
||||
|
||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
|
||||
options->running_mode = core::RunningMode::VIDEO;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
auto results = hand_landmarker->Detect(image);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("not initialized with the image mode"));
|
||||
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
|
||||
|
||||
results = hand_landmarker->DetectAsync(image, 0);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("not initialized with the live stream mode"));
|
||||
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
|
||||
MP_ASSERT_OK(hand_landmarker->Close());
|
||||
}
|
||||
|
||||
TEST_P(VideoModeTest, Succeeds) {
|
||||
const int iterations = 100;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||
GetParam().test_image_name)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
|
||||
options->running_mode = core::RunningMode::VIDEO;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
const auto expected_results = GetParam().expected_results;
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
HandLandmarksDetectionResult hand_landmarker_results;
|
||||
if (GetParam().rotation != 0) {
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = GetParam().rotation;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
hand_landmarker_results,
|
||||
hand_landmarker->DetectForVideo(image, i, image_processing_options));
|
||||
} else {
|
||||
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
|
||||
hand_landmarker->DetectForVideo(image, i));
|
||||
}
|
||||
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
|
||||
expected_results);
|
||||
}
|
||||
MP_ASSERT_OK(hand_landmarker->Close());
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
HandGestureTest, VideoModeTest,
|
||||
Values(TestParams{
|
||||
/* test_name= */ "LandmarksThumbUp",
|
||||
/* test_image_name= */ kThumbUpImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kThumbUpLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "LandmarksPointingUp",
|
||||
/* test_image_name= */ kPointingUpImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kPointingUpLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "LandmarksPointingUpRotated",
|
||||
/* test_image_name= */ kPointingUpRotatedImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ -90,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kPointingUpRotatedLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "NoHands",
|
||||
/* test_image_name= */ kNoHandsImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
{{}, {}, {}},
|
||||
}),
|
||||
[](const TestParamInfo<ImageModeTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
class LiveStreamModeTest : public testing::TestWithParam<TestParams> {};
|
||||
|
||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<HandLandmarksDetectionResult> results,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
auto results = hand_landmarker->Detect(image);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("not initialized with the image mode"));
|
||||
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
|
||||
|
||||
results = hand_landmarker->DetectForVideo(image, 0);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("not initialized with the video mode"));
|
||||
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
|
||||
MP_ASSERT_OK(hand_landmarker->Close());
|
||||
}
|
||||
|
||||
TEST_P(LiveStreamModeTest, Succeeds) {
|
||||
const int iterations = 100;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||
GetParam().test_image_name)));
|
||||
auto options = std::make_unique<HandLandmarkerOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
std::vector<HandLandmarksDetectionResult> hand_landmarker_results;
|
||||
std::vector<std::pair<int, int>> image_sizes;
|
||||
std::vector<int64> timestamps;
|
||||
options->result_callback =
|
||||
[&hand_landmarker_results, &image_sizes, ×tamps](
|
||||
absl::StatusOr<HandLandmarksDetectionResult> results,
|
||||
const Image& image, int64 timestamp_ms) {
|
||||
MP_ASSERT_OK(results.status());
|
||||
hand_landmarker_results.push_back(std::move(results.value()));
|
||||
image_sizes.push_back({image.width(), image.height()});
|
||||
timestamps.push_back(timestamp_ms);
|
||||
};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
HandLandmarksDetectionResult hand_landmarker_results;
|
||||
if (GetParam().rotation != 0) {
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = GetParam().rotation;
|
||||
MP_ASSERT_OK(
|
||||
hand_landmarker->DetectAsync(image, i, image_processing_options));
|
||||
} else {
|
||||
MP_ASSERT_OK(hand_landmarker->DetectAsync(image, i));
|
||||
}
|
||||
}
|
||||
MP_ASSERT_OK(hand_landmarker->Close());
|
||||
// Due to the flow limiter, the total of outputs will be smaller than the
|
||||
// number of iterations.
|
||||
ASSERT_LE(hand_landmarker_results.size(), iterations);
|
||||
ASSERT_GT(hand_landmarker_results.size(), 0);
|
||||
|
||||
const auto expected_results = GetParam().expected_results;
|
||||
for (int i = 0; i < hand_landmarker_results.size(); ++i) {
|
||||
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i],
|
||||
expected_results);
|
||||
}
|
||||
for (const auto& image_size : image_sizes) {
|
||||
EXPECT_EQ(image_size.first, image.width());
|
||||
EXPECT_EQ(image_size.second, image.height());
|
||||
}
|
||||
int64 timestamp_ms = -1;
|
||||
for (const auto& timestamp : timestamps) {
|
||||
EXPECT_GT(timestamp, timestamp_ms);
|
||||
timestamp_ms = timestamp;
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
HandGestureTest, LiveStreamModeTest,
|
||||
Values(TestParams{
|
||||
/* test_name= */ "LandmarksThumbUp",
|
||||
/* test_image_name= */ kThumbUpImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kThumbUpLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "LandmarksPointingUp",
|
||||
/* test_image_name= */ kPointingUpImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kPointingUpLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "LandmarksPointingUpRotated",
|
||||
/* test_image_name= */ kPointingUpRotatedImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ -90,
|
||||
/* expected_results = */
|
||||
GetExpectedHandLandmarksDetectionResult(
|
||||
{kPointingUpRotatedLandmarksFilename}),
|
||||
},
|
||||
TestParams{
|
||||
/* test_name= */ "NoHands",
|
||||
/* test_image_name= */ kNoHandsImage,
|
||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||
/* rotation= */ 0,
|
||||
/* expected_results = */
|
||||
{{}, {}, {}},
|
||||
}),
|
||||
[](const TestParamInfo<ImageModeTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
} // namespace hand_landmarker
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -32,7 +32,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
|
@ -63,7 +63,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:label_map_util",
|
||||
|
|
|
@ -23,10 +23,12 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_segmenter {
|
||||
namespace {
|
||||
|
||||
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
|
||||
|
@ -37,23 +39,24 @@ constexpr char kImageTag[] = "IMAGE";
|
|||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.ImageSegmenterGraph";
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::tasks::components::proto::SegmenterOptions;
|
||||
using ImageSegmenterOptionsProto =
|
||||
image_segmenter::proto::ImageSegmenterOptions;
|
||||
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||
|
||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// "mediapipe.tasks.vision.ImageSegmenterGraph".
|
||||
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<ImageSegmenterOptionsProto> options,
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
|
||||
bool enable_flow_limiting) {
|
||||
api2::builder::Graph graph;
|
||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
|
||||
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
||||
options.get());
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
||||
|
@ -72,9 +75,9 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
|
||||
// Converts the user-facing ImageSegmenterOptions struct to the internal
|
||||
// ImageSegmenterOptions proto.
|
||||
std::unique_ptr<ImageSegmenterOptionsProto> ConvertImageSegmenterOptionsToProto(
|
||||
ImageSegmenterOptions* options) {
|
||||
auto options_proto = std::make_unique<ImageSegmenterOptionsProto>();
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto>
|
||||
ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
|
||||
auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
|
@ -137,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
|||
};
|
||||
}
|
||||
return core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||
ImageSegmenterOptionsProto>(
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
|
@ -211,6 +214,7 @@ absl::Status ImageSegmenter::SegmentAsync(
|
|||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||
}
|
||||
|
||||
} // namespace image_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -26,12 +26,12 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_segmenter {
|
||||
|
||||
// The options for configuring a mediapipe image segmenter task.
|
||||
struct ImageSegmenterOptions {
|
||||
|
@ -191,6 +191,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
|||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace image_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -35,7 +35,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "mediapipe/util/label_map_util.h"
|
||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_segmenter {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -55,7 +56,8 @@ using ::mediapipe::api2::builder::MultiSource;
|
|||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::proto::SegmenterOptions;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::proto::
|
||||
ImageSegmenterGraphOptions;
|
||||
using ::tflite::Tensor;
|
||||
using ::tflite::TensorMetadata;
|
||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||
|
@ -77,7 +79,7 @@ struct ImageSegmenterOutputs {
|
|||
|
||||
} // namespace
|
||||
|
||||
absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) {
|
||||
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
|
||||
if (options.segmenter_options().output_type() ==
|
||||
SegmenterOptions::UNSPECIFIED) {
|
||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||
|
@ -112,7 +114,7 @@ absl::StatusOr<LabelItems> GetLabelItemsIfAny(
|
|||
}
|
||||
|
||||
absl::Status ConfigureTensorsToSegmentationCalculator(
|
||||
const ImageSegmenterOptions& segmenter_option,
|
||||
const ImageSegmenterGraphOptions& segmenter_option,
|
||||
const core::ModelResources& model_resources,
|
||||
TensorsToSegmentationCalculatorOptions* options) {
|
||||
*options->mutable_segmenter_options() = segmenter_option.segmenter_options();
|
||||
|
@ -181,7 +183,7 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
|
|||
// input_stream: "IMAGE:image"
|
||||
// output_stream: "SEGMENTATION:segmented_masks"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext]
|
||||
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -200,12 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||
mediapipe::SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<ImageSegmenterOptions>(sc));
|
||||
CreateModelResources<ImageSegmenterGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_streams,
|
||||
BuildSegmentationTask(
|
||||
sc->Options<ImageSegmenterOptions>(), *model_resources,
|
||||
sc->Options<ImageSegmenterGraphOptions>(), *model_resources,
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
|
||||
|
@ -228,13 +230,13 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
// builder::Graph instance. The segmentation pipeline takes images
|
||||
// (mediapipe::Image) as the input and returns segmented image mask as output.
|
||||
//
|
||||
// task_options: the mediapipe tasks ImageSegmenterOptions proto.
|
||||
// task_options: the mediapipe tasks ImageSegmenterGraphOptions proto.
|
||||
// model_resources: the ModelSources object initialized from a segmentation
|
||||
// model file with model metadata.
|
||||
// image_in: (mediapipe::Image) stream to run segmentation on.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
||||
const ImageSegmenterOptions& task_options,
|
||||
const ImageSegmenterGraphOptions& task_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||
|
@ -293,8 +295,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageSegmenterGraph);
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::image_segmenter::ImageSegmenterGraph);
|
||||
|
||||
} // namespace image_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -33,7 +33,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
|
@ -42,6 +42,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_segmenter {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::Image;
|
||||
|
@ -547,6 +548,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
// TODO: Add test for hair segmentation model.
|
||||
|
||||
} // namespace
|
||||
} // namespace image_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "image_segmenter_options_proto",
|
||||
srcs = ["image_segmenter_options.proto"],
|
||||
name = "image_segmenter_graph_options_proto",
|
||||
srcs = ["image_segmenter_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
|
|
@ -21,9 +21,12 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/proto/segmenter_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message ImageSegmenterOptions {
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto";
|
||||
option java_outer_classname = "ImageSegmenterGraphOptionsProto";
|
||||
|
||||
message ImageSegmenterGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ImageSegmenterOptions ext = 458105758;
|
||||
optional ImageSegmenterGraphOptions ext = 458105758;
|
||||
}
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
|
@ -20,6 +20,7 @@ android_library(
|
|||
name = "category",
|
||||
srcs = ["Category.java"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
|
@ -36,20 +37,29 @@ android_library(
|
|||
)
|
||||
|
||||
android_library(
|
||||
name = "classification_entry",
|
||||
srcs = ["ClassificationEntry.java"],
|
||||
name = "classifications",
|
||||
srcs = ["Classifications.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
":category",
|
||||
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "classifications",
|
||||
srcs = ["Classifications.java"],
|
||||
name = "classificationresult",
|
||||
srcs = ["ClassificationResult.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
":classification_entry",
|
||||
":classifications",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
|
@ -38,6 +39,16 @@ public abstract class Category {
|
|||
return new AutoValue_Category(score, index, categoryName, displayName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf
|
||||
* message.
|
||||
*
|
||||
* @param proto the {@link ClassificationProto.Classification} protobuf message to convert.
|
||||
*/
|
||||
public static Category createFromProto(ClassificationProto.Classification proto) {
|
||||
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
|
||||
}
|
||||
|
||||
/** The probability score of this label category. */
|
||||
public abstract float score();
|
||||
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents a list of predicted categories with an optional timestamp. Typically used as result
|
||||
* for classification tasks.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class ClassificationEntry {
|
||||
/**
|
||||
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
|
||||
* timestamp.
|
||||
*
|
||||
* @param categories the list of {@link Category} objects that contain category name, display
|
||||
* name, score and label index.
|
||||
* @param timestampMs the {@link long} representing the timestamp for which these categories were
|
||||
* obtained.
|
||||
*/
|
||||
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
|
||||
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
|
||||
}
|
||||
|
||||
/** The list of predicted {@link Category} objects, sorted by descending score. */
|
||||
public abstract List<Category> categories();
|
||||
|
||||
/**
|
||||
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
|
||||
* series use cases, e.g. audio classification.
|
||||
*/
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Represents the classification results of a model. Typically used as a result for classification
|
||||
* tasks.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class ClassificationResult {
|
||||
|
||||
/**
|
||||
* Creates a {@link ClassificationResult} instance.
|
||||
*
|
||||
* @param classifications the list of {@link Classifications} objects containing the predicted
|
||||
* categories for each head of the model.
|
||||
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
* corresponding to these results.
|
||||
*/
|
||||
public static ClassificationResult create(
|
||||
List<Classifications> classifications, Optional<Long> timestampMs) {
|
||||
return new AutoValue_ClassificationResult(
|
||||
Collections.unmodifiableList(classifications), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link ClassificationResult} object from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||
*/
|
||||
public static ClassificationResult createFromProto(
|
||||
ClassificationsProto.ClassificationResult proto) {
|
||||
List<Classifications> classifications = new ArrayList<>();
|
||||
for (ClassificationsProto.Classifications classificationsProto :
|
||||
proto.getClassificationsList()) {
|
||||
classifications.add(Classifications.createFromProto(classificationsProto));
|
||||
}
|
||||
Optional<Long> timestampMs =
|
||||
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
|
||||
return create(classifications, timestampMs);
|
||||
}
|
||||
|
||||
/** The classification results for each head of the model. */
|
||||
public abstract List<Classifications> classifications();
|
||||
|
||||
/**
|
||||
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
|
||||
* these results.
|
||||
*
|
||||
* <p>This is only used for classification on time series (e.g. audio classification). In these
|
||||
* use cases, the amount of data to process might exceed the maximum size that the model can
|
||||
* process: to solve this, the input data is split into multiple chunks starting at different
|
||||
* timestamps.
|
||||
*/
|
||||
public abstract Optional<Long> timestampMs();
|
||||
}
|
|
@ -15,8 +15,12 @@
|
|||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Represents the list of classification for a given classifier head. Typically used as a result for
|
||||
|
@ -28,25 +32,41 @@ public abstract class Classifications {
|
|||
/**
|
||||
* Creates a {@link Classifications} instance.
|
||||
*
|
||||
* @param entries the list of {@link ClassificationEntry} objects containing the predicted
|
||||
* categories.
|
||||
* @param categories the list of {@link Category} objects containing the predicted categories.
|
||||
* @param headIndex the index of the classifier head.
|
||||
* @param headName the name of the classifier head.
|
||||
* @param headName the optional name of the classifier head.
|
||||
*/
|
||||
public static Classifications create(
|
||||
List<ClassificationEntry> entries, int headIndex, String headName) {
|
||||
List<Category> categories, int headIndex, Optional<String> headName) {
|
||||
return new AutoValue_Classifications(
|
||||
Collections.unmodifiableList(entries), headIndex, headName);
|
||||
Collections.unmodifiableList(categories), headIndex, headName);
|
||||
}
|
||||
|
||||
/** A list of {@link ClassificationEntry} objects. */
|
||||
public abstract List<ClassificationEntry> entries();
|
||||
/**
|
||||
* Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications}
|
||||
* protobuf message.
|
||||
*
|
||||
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
|
||||
*/
|
||||
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
|
||||
List<Category> categories = new ArrayList<>();
|
||||
for (ClassificationProto.Classification classificationProto :
|
||||
proto.getClassificationList().getClassificationList()) {
|
||||
categories.add(Category.createFromProto(classificationProto));
|
||||
}
|
||||
Optional<String> headName =
|
||||
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
|
||||
return create(categories, proto.getHeadIndex(), headName);
|
||||
}
|
||||
|
||||
/** A list of {@link Category} objects. */
|
||||
public abstract List<Category> categories();
|
||||
|
||||
/**
|
||||
* The index of the classifier head these entries refer to. This is useful for multi-head models.
|
||||
*/
|
||||
public abstract int headIndex();
|
||||
|
||||
/** The name of the classifier head, which is the corresponding tensor metadata name. */
|
||||
public abstract String headName();
|
||||
/** The optional name of the classifier head, which is the corresponding tensor metadata name. */
|
||||
public abstract Optional<String> headName();
|
||||
}
|
||||
|
|
|
@ -26,22 +26,23 @@ public abstract class BaseOptions {
|
|||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/**
|
||||
* Sets the model path to a tflite model with metadata in the assets.
|
||||
* Sets the model path to a model asset file (a tflite model or a model asset bundle file) in
|
||||
* the Android app assets folder.
|
||||
*
|
||||
* <p>Note: when model path is set, both model file descriptor and model buffer should be empty.
|
||||
*/
|
||||
public abstract Builder setModelAssetPath(String value);
|
||||
|
||||
/**
|
||||
* Sets the native fd int of a tflite model with metadata.
|
||||
* Sets the native fd int of a model asset file (a tflite model or a model asset bundle file).
|
||||
*
|
||||
* <p>Note: when model file descriptor is set, both model path and model buffer should be empty.
|
||||
*/
|
||||
public abstract Builder setModelAssetFileDescriptor(Integer value);
|
||||
|
||||
/**
|
||||
* Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model
|
||||
* with metadata.
|
||||
* Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a model asset
|
||||
* file (a tflite model or a model asset bundle file).
|
||||
*
|
||||
* <p>Note: when model buffer is set, both model file and model file descriptor should be empty.
|
||||
*/
|
||||
|
|
|
@ -22,6 +22,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
|
||||
|
@ -36,6 +37,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||
|
@ -232,7 +234,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
|
|||
"//mediapipe/framework/formats:rect_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
|
|
|
@ -37,8 +37,8 @@ cc_library(
|
|||
android_library(
|
||||
name = "textclassifier",
|
||||
srcs = [
|
||||
"textclassifier/TextClassificationResult.java",
|
||||
"textclassifier/TextClassifier.java",
|
||||
"textclassifier/TextClassifierResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
|
@ -51,9 +51,7 @@ android_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||
|
|
|
@ -1,103 +0,0 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.text.textclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the classification results generated by {@link TextClassifier}. */
|
||||
@AutoValue
|
||||
public abstract class TextClassificationResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link TextClassificationResult} instance from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
|
||||
* message.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static TextClassificationResult create(
|
||||
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
|
||||
List<Classifications> classifications = new ArrayList<>();
|
||||
for (ClassificationsProto.Classifications classificationsProto :
|
||||
classificationResult.getClassificationsList()) {
|
||||
classifications.add(classificationsFromProto(classificationsProto));
|
||||
}
|
||||
return new AutoValue_TextClassificationResult(
|
||||
timestampMs, Collections.unmodifiableList(classifications));
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
|
||||
/** Contains one set of results per classifier head. */
|
||||
@SuppressWarnings("AutoValueImmutableFields")
|
||||
public abstract List<Classifications> classifications();
|
||||
|
||||
/**
|
||||
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
|
||||
*
|
||||
* @param category the {@link CategoryProto.Category} protobuf message to convert.
|
||||
*/
|
||||
static Category categoryFromProto(CategoryProto.Category category) {
|
||||
return Category.create(
|
||||
category.getScore(),
|
||||
category.getIndex(),
|
||||
category.getCategoryName(),
|
||||
category.getDisplayName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
|
||||
* ClassificationEntry} object.
|
||||
*
|
||||
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
|
||||
*/
|
||||
static ClassificationEntry classificationEntryFromProto(
|
||||
ClassificationsProto.ClassificationEntry entry) {
|
||||
List<Category> categories = new ArrayList<>();
|
||||
for (CategoryProto.Category category : entry.getCategoriesList()) {
|
||||
categories.add(categoryFromProto(category));
|
||||
}
|
||||
return ClassificationEntry.create(categories, entry.getTimestampMs());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
|
||||
* Classifications} object.
|
||||
*
|
||||
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
|
||||
* convert.
|
||||
*/
|
||||
static Classifications classificationsFromProto(
|
||||
ClassificationsProto.Classifications classifications) {
|
||||
List<ClassificationEntry> entries = new ArrayList<>();
|
||||
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
|
||||
entries.add(classificationEntryFromProto(entry));
|
||||
}
|
||||
return Classifications.create(
|
||||
entries, classifications.getHeadIndex(), classifications.getHeadName());
|
||||
}
|
||||
}
|
|
@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
|
|||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
|
@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable {
|
|||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
|
||||
Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out"));
|
||||
|
||||
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
|
||||
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||
private final TaskRunner runner;
|
||||
|
@ -142,17 +142,18 @@ public final class TextClassifier implements AutoCloseable {
|
|||
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
|
||||
*/
|
||||
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
|
||||
OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>();
|
||||
OutputHandler<TextClassifierResult, Void> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() {
|
||||
new OutputHandler.OutputPacketConverter<TextClassifierResult, Void>() {
|
||||
@Override
|
||||
public TextClassificationResult convertToTaskResult(List<Packet> packets) {
|
||||
public TextClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||
try {
|
||||
return TextClassificationResult.create(
|
||||
return TextClassifierResult.create(
|
||||
ClassificationResult.createFromProto(
|
||||
PacketGetter.getProto(
|
||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
|
||||
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.getDefaultInstance())),
|
||||
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||
} catch (IOException e) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||
|
@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable {
|
|||
*
|
||||
* @param inputText a {@link String} for processing.
|
||||
*/
|
||||
public TextClassificationResult classify(String inputText) {
|
||||
public TextClassifierResult classify(String inputText) {
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
|
||||
return (TextClassificationResult) runner.process(inputPackets);
|
||||
return (TextClassifierResult) runner.process(inputPackets);
|
||||
}
|
||||
|
||||
/** Closes and cleans up the {@link TextClassifier}. */
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.text.textclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
|
||||
/** Represents the classification results generated by {@link TextClassifier}. */
|
||||
@AutoValue
|
||||
public abstract class TextClassifierResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link TextClassifierResult} instance.
|
||||
*
|
||||
* @param classificationResult the {@link ClassificationResult} object containing one set of
|
||||
* results per classifier head.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
|
||||
return new AutoValue_TextClassifierResult(classificationResult, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link TextClassifierResult} instance from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static TextClassifierResult createFromProto(
|
||||
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||
return create(ClassificationResult.createFromProto(proto), timestampMs);
|
||||
}
|
||||
|
||||
/** Contains one set of results per classifier head. */
|
||||
public abstract ClassificationResult classificationResult();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -84,8 +84,8 @@ android_library(
|
|||
android_library(
|
||||
name = "imageclassifier",
|
||||
srcs = [
|
||||
"imageclassifier/ImageClassificationResult.java",
|
||||
"imageclassifier/ImageClassifier.java",
|
||||
"imageclassifier/ImageClassifierResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
|
@ -100,9 +100,7 @@ android_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
|
|
|
@ -1,102 +0,0 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the classification results generated by {@link ImageClassifier}. */
|
||||
@AutoValue
|
||||
public abstract class ImageClassificationResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageClassificationResult} instance from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
|
||||
* message.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static ImageClassificationResult create(
|
||||
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
|
||||
List<Classifications> classifications = new ArrayList<>();
|
||||
for (ClassificationsProto.Classifications classificationsProto :
|
||||
classificationResult.getClassificationsList()) {
|
||||
classifications.add(classificationsFromProto(classificationsProto));
|
||||
}
|
||||
return new AutoValue_ImageClassificationResult(
|
||||
timestampMs, Collections.unmodifiableList(classifications));
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
|
||||
/** Contains one set of results per classifier head. */
|
||||
public abstract List<Classifications> classifications();
|
||||
|
||||
/**
|
||||
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
|
||||
*
|
||||
* @param category the {@link CategoryProto.Category} protobuf message to convert.
|
||||
*/
|
||||
static Category categoryFromProto(CategoryProto.Category category) {
|
||||
return Category.create(
|
||||
category.getScore(),
|
||||
category.getIndex(),
|
||||
category.getCategoryName(),
|
||||
category.getDisplayName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
|
||||
* ClassificationEntry} object.
|
||||
*
|
||||
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
|
||||
*/
|
||||
static ClassificationEntry classificationEntryFromProto(
|
||||
ClassificationsProto.ClassificationEntry entry) {
|
||||
List<Category> categories = new ArrayList<>();
|
||||
for (CategoryProto.Category category : entry.getCategoriesList()) {
|
||||
categories.add(categoryFromProto(category));
|
||||
}
|
||||
return ClassificationEntry.create(categories, entry.getTimestampMs());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
|
||||
* Classifications} object.
|
||||
*
|
||||
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
|
||||
* convert.
|
||||
*/
|
||||
static Classifications classificationsFromProto(
|
||||
ClassificationsProto.Classifications classifications) {
|
||||
List<ClassificationEntry> entries = new ArrayList<>();
|
||||
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
|
||||
entries.add(classificationEntryFromProto(entry));
|
||||
}
|
||||
return Classifications.create(
|
||||
entries, classifications.getHeadIndex(), classifications.getHeadName());
|
||||
}
|
||||
}
|
|
@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
|||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
|
@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out"));
|
||||
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
|
||||
Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out"));
|
||||
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||
|
@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
||||
*/
|
||||
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
|
||||
OutputHandler<ImageClassificationResult, MPImage> handler = new OutputHandler<>();
|
||||
OutputHandler<ImageClassifierResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ImageClassificationResult, MPImage>() {
|
||||
new OutputHandler.OutputPacketConverter<ImageClassifierResult, MPImage>() {
|
||||
@Override
|
||||
public ImageClassificationResult convertToTaskResult(List<Packet> packets) {
|
||||
public ImageClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||
try {
|
||||
return ImageClassificationResult.create(
|
||||
return ImageClassifierResult.create(
|
||||
ClassificationResult.createFromProto(
|
||||
PacketGetter.getProto(
|
||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
|
||||
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.getDefaultInstance())),
|
||||
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||
} catch (IOException e) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||
|
@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageClassificationResult classify(MPImage image) {
|
||||
public ImageClassifierResult classify(MPImage image) {
|
||||
return classify(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
|
@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
* input image before running inference.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageClassificationResult classify(
|
||||
public ImageClassifierResult classify(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
return (ImageClassificationResult) processImageData(image, imageProcessingOptions);
|
||||
return (ImageClassifierResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) {
|
||||
public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) {
|
||||
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
|
@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageClassificationResult classifyForVideo(
|
||||
public ImageClassifierResult classifyForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
* the image classifier is in the live stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageClassificationResult, MPImage> resultListener);
|
||||
ResultListener<ImageClassifierResult, MPImage> resultListener);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}. */
|
||||
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||
|
@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
|
||||
abstract Optional<ClassifierOptions> classifierOptions();
|
||||
|
||||
abstract Optional<ResultListener<ImageClassificationResult, MPImage>> resultListener();
|
||||
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
|
||||
/** Represents the classification results generated by {@link ImageClassifier}. */
|
||||
@AutoValue
|
||||
public abstract class ImageClassifierResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageClassifierResult} instance.
|
||||
*
|
||||
* @param classificationResult the {@link ClassificationResult} object containing one set of
|
||||
* results per classifier head.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
|
||||
return new AutoValue_ImageClassifierResult(classificationResult, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageClassifierResult} instance from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static ImageClassifierResult createFromProto(
|
||||
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||
return create(ClassificationResult.createFromProto(proto), timestampMs);
|
||||
}
|
||||
|
||||
/** Contains one set of results per classifier head. */
|
||||
public abstract ClassificationResult classificationResult();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -76,7 +76,7 @@ public class TextClassifierTest {
|
|||
public void classify_succeedsWithBert() throws Exception {
|
||||
TextClassifier textClassifier =
|
||||
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
assertHasOneHead(negativeResults);
|
||||
assertCategoriesAre(
|
||||
negativeResults,
|
||||
|
@ -84,7 +84,7 @@ public class TextClassifierTest {
|
|||
Category.create(0.95630914f, 0, "negative", ""),
|
||||
Category.create(0.04369091f, 1, "positive", "")));
|
||||
|
||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertCategoriesAre(
|
||||
positiveResults,
|
||||
|
@ -99,7 +99,7 @@ public class TextClassifierTest {
|
|||
TextClassifier.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(),
|
||||
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
|
||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
assertHasOneHead(negativeResults);
|
||||
assertCategoriesAre(
|
||||
negativeResults,
|
||||
|
@ -107,7 +107,7 @@ public class TextClassifierTest {
|
|||
Category.create(0.95630914f, 0, "negative", ""),
|
||||
Category.create(0.04369091f, 1, "positive", "")));
|
||||
|
||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertCategoriesAre(
|
||||
|
@ -122,7 +122,7 @@ public class TextClassifierTest {
|
|||
TextClassifier textClassifier =
|
||||
TextClassifier.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
|
||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
assertHasOneHead(negativeResults);
|
||||
assertCategoriesAre(
|
||||
negativeResults,
|
||||
|
@ -130,7 +130,7 @@ public class TextClassifierTest {
|
|||
Category.create(0.6647746f, 0, "Negative", ""),
|
||||
Category.create(0.33522537f, 1, "Positive", "")));
|
||||
|
||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertCategoriesAre(
|
||||
positiveResults,
|
||||
|
@ -139,16 +139,15 @@ public class TextClassifierTest {
|
|||
Category.create(0.48799595f, 1, "Positive", "")));
|
||||
}
|
||||
|
||||
private static void assertHasOneHead(TextClassificationResult results) {
|
||||
assertThat(results.classifications()).hasSize(1);
|
||||
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
|
||||
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
|
||||
assertThat(results.classifications().get(0).entries()).hasSize(1);
|
||||
private static void assertHasOneHead(TextClassifierResult results) {
|
||||
assertThat(results.classificationResult().classifications()).hasSize(1);
|
||||
assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
|
||||
assertThat(results.classificationResult().classifications().get(0).headName().get())
|
||||
.isEqualTo("probability");
|
||||
}
|
||||
|
||||
private static void assertCategoriesAre(
|
||||
TextClassificationResult results, List<Category> categories) {
|
||||
assertThat(results.classifications().get(0).entries().get(0).categories())
|
||||
private static void assertCategoriesAre(TextClassifierResult results, List<Category> categories) {
|
||||
assertThat(results.classificationResult().classifications().get(0).categories())
|
||||
.isEqualTo(categories);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,11 +91,12 @@ public class ImageClassifierTest {
|
|||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001);
|
||||
assertThat(results.classifications().get(0).entries().get(0).categories().get(0))
|
||||
assertHasOneHead(results);
|
||||
assertThat(results.classificationResult().classifications().get(0).categories())
|
||||
.hasSize(1001);
|
||||
assertThat(results.classificationResult().classifications().get(0).categories().get(0))
|
||||
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
|
||||
}
|
||||
|
||||
|
@ -108,9 +109,9 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results,
|
||||
Arrays.asList(
|
||||
|
@ -128,9 +129,9 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
|
||||
}
|
||||
|
@ -144,9 +145,9 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results,
|
||||
Arrays.asList(
|
||||
|
@ -166,9 +167,9 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results,
|
||||
Arrays.asList(
|
||||
|
@ -190,9 +191,9 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results,
|
||||
Arrays.asList(
|
||||
|
@ -214,10 +215,10 @@ public class ImageClassifierTest {
|
|||
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
|
||||
ImageClassificationResult results =
|
||||
ImageClassifierResult results =
|
||||
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
|
||||
}
|
||||
|
@ -233,10 +234,10 @@ public class ImageClassifierTest {
|
|||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||
ImageClassificationResult results =
|
||||
ImageClassifierResult results =
|
||||
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results,
|
||||
Arrays.asList(
|
||||
|
@ -258,11 +259,11 @@ public class ImageClassifierTest {
|
|||
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
|
||||
ImageClassificationResult results =
|
||||
ImageClassifierResult results =
|
||||
imageClassifier.classify(
|
||||
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
|
||||
}
|
||||
|
@ -391,9 +392,9 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||
|
||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||
}
|
||||
|
@ -410,9 +411,8 @@ public class ImageClassifierTest {
|
|||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
ImageClassificationResult results =
|
||||
imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
||||
assertHasOneHeadAndOneTimestamp(results, i);
|
||||
ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||
}
|
||||
|
@ -478,24 +478,17 @@ public class ImageClassifierTest {
|
|||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||
}
|
||||
|
||||
private static void assertHasOneHeadAndOneTimestamp(
|
||||
ImageClassificationResult results, long timestampMs) {
|
||||
assertThat(results.classifications()).hasSize(1);
|
||||
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
|
||||
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
|
||||
assertThat(results.classifications().get(0).entries()).hasSize(1);
|
||||
assertThat(results.classifications().get(0).entries().get(0).timestampMs())
|
||||
.isEqualTo(timestampMs);
|
||||
private static void assertHasOneHead(ImageClassifierResult results) {
|
||||
assertThat(results.classificationResult().classifications()).hasSize(1);
|
||||
assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
|
||||
assertThat(results.classificationResult().classifications().get(0).headName().get())
|
||||
.isEqualTo("probability");
|
||||
}
|
||||
|
||||
private static void assertCategoriesAre(
|
||||
ImageClassificationResult results, List<Category> categories) {
|
||||
assertThat(results.classifications().get(0).entries().get(0).categories())
|
||||
.hasSize(categories.size());
|
||||
for (int i = 0; i < categories.size(); i++) {
|
||||
assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
|
||||
.isEqualTo(categories.get(i));
|
||||
}
|
||||
ImageClassifierResult results, List<Category> categories) {
|
||||
assertThat(results.classificationResult().classifications().get(0).categories())
|
||||
.isEqualTo(categories);
|
||||
}
|
||||
|
||||
private static void assertImageSizeIsExpected(MPImage inputImage) {
|
||||
|
|
|
@ -43,3 +43,9 @@ py_library(
|
|||
srcs = ["image_classifier.py"],
|
||||
deps = [":metadata_writer"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "text_classifier",
|
||||
srcs = ["text_classifier.py"],
|
||||
deps = [":metadata_writer"],
|
||||
)
|
||||
|
|
|
@ -62,10 +62,10 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
|
|||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
|
||||
Returns:
|
||||
An MetadataWrite object.
|
||||
A MetadataWriter object.
|
||||
"""
|
||||
writer = metadata_writer.MetadataWriter(model_buffer)
|
||||
writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
writer.add_image_input(input_norm_mean, input_norm_std)
|
||||
writer.add_classification_output(labels, score_calibration)
|
||||
return cls(writer)
|
||||
|
|
|
@ -228,6 +228,45 @@ class ScoreThresholdingMd:
|
|||
return score_thresholding
|
||||
|
||||
|
||||
class RegexTokenizerMd:
|
||||
"""A container for the Regex tokenizer [1] metadata information.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
|
||||
"""
|
||||
|
||||
def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
|
||||
"""Initializes a RegexTokenizerMd object.
|
||||
|
||||
Args:
|
||||
delim_regex_pattern: the regular expression to segment strings and create
|
||||
tokens.
|
||||
vocab_file_path: path to the vocabulary file.
|
||||
"""
|
||||
self._delim_regex_pattern = delim_regex_pattern
|
||||
self._vocab_file_path = vocab_file_path
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
|
||||
"""Creates the Regex tokenizer metadata based on the information.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the Regex tokenizer metadata.
|
||||
"""
|
||||
vocab = _metadata_fb.AssociatedFileT()
|
||||
vocab.name = self._vocab_file_path
|
||||
vocab.description = _VOCAB_FILE_DESCRIPTION
|
||||
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
|
||||
|
||||
# Create the RegexTokenizer.
|
||||
tokenizer = _metadata_fb.ProcessUnitT()
|
||||
tokenizer.optionsType = (
|
||||
_metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
|
||||
tokenizer.options = _metadata_fb.RegexTokenizerOptionsT()
|
||||
tokenizer.options.delimRegexPattern = self._delim_regex_pattern
|
||||
tokenizer.options.vocabFile = [vocab]
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TensorMd:
|
||||
"""A container for common tensor metadata information.
|
||||
|
||||
|
@ -397,6 +436,56 @@ class InputImageTensorMd(TensorMd):
|
|||
return tensor_metadata
|
||||
|
||||
|
||||
class InputTextTensorMd(TensorMd):
|
||||
"""A container for the input text tensor metadata information.
|
||||
|
||||
Attributes:
|
||||
tokenizer_md: information of the tokenizer in the input text tensor, if any.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
tokenizer_md: Optional[RegexTokenizerMd] = None):
|
||||
"""Initializes the instance of InputTextTensorMd.
|
||||
|
||||
Args:
|
||||
name: name of the tensor.
|
||||
description: description of what the tensor is.
|
||||
tokenizer_md: information of the tokenizer in the input text tensor, if
|
||||
any. Only `RegexTokenizer` [1] is currenly supported. If the tokenizer
|
||||
is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to
|
||||
`BertInputTensorsMd` class.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
|
||||
"""
|
||||
super().__init__(name, description)
|
||||
self.tokenizer_md = tokenizer_md
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
|
||||
"""Creates the input text metadata based on the information.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the input text metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: if the type of tokenizer_md is unsupported.
|
||||
"""
|
||||
if not isinstance(self.tokenizer_md, (type(None), RegexTokenizerMd)):
|
||||
raise ValueError(
|
||||
f"The type of tokenizer_options, {type(self.tokenizer_md)}, is "
|
||||
f"unsupported")
|
||||
|
||||
tensor_metadata = super().create_metadata()
|
||||
if self.tokenizer_md:
|
||||
tensor_metadata.processUnits = [self.tokenizer_md.create_metadata()]
|
||||
return tensor_metadata
|
||||
|
||||
|
||||
class ClassificationTensorMd(TensorMd):
|
||||
"""A container for the classification tensor metadata information.
|
||||
|
||||
|
|
|
@ -29,6 +29,9 @@ from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
|||
|
||||
_INPUT_IMAGE_NAME = 'image'
|
||||
_INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
|
||||
_INPUT_REGEX_TEXT_NAME = 'input_text'
|
||||
_INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
|
||||
'text to be processed.')
|
||||
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
||||
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
|
||||
|
||||
|
@ -82,6 +85,22 @@ class ScoreThresholding:
|
|||
global_score_threshold: float
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegexTokenizer:
|
||||
"""Parameters of the Regex tokenizer [1] metadata information.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
|
||||
|
||||
Attributes:
|
||||
delim_regex_pattern: the regular expression to segment strings and create
|
||||
tokens.
|
||||
vocab_file_path: path to the vocabulary file.
|
||||
"""
|
||||
delim_regex_pattern: str
|
||||
vocab_file_path: str
|
||||
|
||||
|
||||
class Labels(object):
|
||||
"""Simple container holding classification labels of a particular tensor.
|
||||
|
||||
|
@ -355,11 +374,11 @@ class MetadataWriter(object):
|
|||
if os.path.exists(self._temp_folder.name):
|
||||
self._temp_folder.cleanup()
|
||||
|
||||
def add_genernal_info(
|
||||
def add_general_info(
|
||||
self,
|
||||
model_name: str,
|
||||
model_description: Optional[str] = None) -> 'MetadataWriter':
|
||||
"""Adds a genernal info metadata for the general metadata informantion."""
|
||||
"""Adds a general info metadata for the general metadata informantion."""
|
||||
# Will overwrite the previous `self._general_md` if exists.
|
||||
self._general_md = metadata_info.GeneralMd(
|
||||
name=model_name, description=model_description)
|
||||
|
@ -415,6 +434,34 @@ class MetadataWriter(object):
|
|||
self._input_mds.append(input_md)
|
||||
return self
|
||||
|
||||
def add_regex_text_input(
|
||||
self,
|
||||
regex_tokenizer: RegexTokenizer,
|
||||
name: str = _INPUT_REGEX_TEXT_NAME,
|
||||
description: str = _INPUT_REGEX_TEXT_DESCRIPTION) -> 'MetadataWriter':
|
||||
"""Adds an input text metadata for the text input with regex tokenizer.
|
||||
|
||||
Args:
|
||||
regex_tokenizer: information of the regex tokenizer [1] used to process
|
||||
the input string.
|
||||
name: Name of the input tensor.
|
||||
description: Description of the input tensor.
|
||||
|
||||
Returns:
|
||||
The MetaWriter instance, can be used for chained operation.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
|
||||
"""
|
||||
tokenizer_md = metadata_info.RegexTokenizerMd(
|
||||
delim_regex_pattern=regex_tokenizer.delim_regex_pattern,
|
||||
vocab_file_path=regex_tokenizer.vocab_file_path)
|
||||
input_md = metadata_info.InputTextTensorMd(
|
||||
name=name, description=description, tokenizer_md=tokenizer_md)
|
||||
self._input_mds.append(input_md)
|
||||
self._associated_files.append(regex_tokenizer.vocab_file_path)
|
||||
return self
|
||||
|
||||
def add_classification_output(
|
||||
self,
|
||||
labels: Optional[Labels] = None,
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and label file to the Text classifier models."""
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
_MODEL_NAME = "TextClassifier"
|
||||
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
|
||||
|
||||
|
||||
class MetadataWriter(metadata_writer.MetadataWriterBase):
|
||||
"""MetadataWriter to write the metadata into the text classifier."""
|
||||
|
||||
@classmethod
|
||||
def create_for_regex_model(
|
||||
cls, model_buffer: bytearray,
|
||||
regex_tokenizer: metadata_writer.RegexTokenizer,
|
||||
labels: metadata_writer.Labels) -> "MetadataWriter":
|
||||
"""Creates MetadataWriter for TFLite model with regex tokentizer.
|
||||
|
||||
The parameters required in this method are mandatory when using MediaPipe
|
||||
Tasks.
|
||||
|
||||
Note that only the output TFLite is used for deployment. The output JSON
|
||||
content is used to interpret the metadata content.
|
||||
|
||||
Args:
|
||||
model_buffer: A valid flatbuffer loaded from the TFLite model file.
|
||||
regex_tokenizer: information of the regex tokenizer [1] used to process
|
||||
the input string. If the tokenizer is `BertTokenizer` [2] or
|
||||
`SentencePieceTokenizer` [3], please refer to
|
||||
`create_for_bert_model`.
|
||||
labels: an instance of Labels helper class used in the output
|
||||
classification tensor [4].
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
|
||||
[4]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
|
||||
Returns:
|
||||
A MetadataWriter object.
|
||||
"""
|
||||
writer = metadata_writer.MetadataWriter(model_buffer)
|
||||
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
writer.add_regex_text_input(regex_tokenizer)
|
||||
writer.add_classification_output(labels)
|
||||
return cls(writer)
|
|
@ -174,12 +174,10 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
self.assertIsInstance(classifier, _AudioClassifier)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"ExternalFile must specify at least one of 'file_content', "
|
||||
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||
base_options = _BaseOptions(model_asset_path='')
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
options = _AudioClassifierOptions(base_options=base_options)
|
||||
_AudioClassifier.create_from_options(options)
|
||||
|
||||
|
|
|
@ -53,3 +53,17 @@ py_test(
|
|||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "text_classifier_test",
|
||||
srcs = ["text_classifier_test.py"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -191,6 +191,43 @@ class InputImageTensorMdTest(parameterized.TestCase):
|
|||
f"{len(norm_mean)} and {len(norm_std)}", str(error.exception))
|
||||
|
||||
|
||||
class InputTextTensorMdTest(absltest.TestCase):
|
||||
|
||||
_NAME = "input text"
|
||||
_DESCRIPTION = "The input string."
|
||||
_VOCAB_FILE = "vocab.txt"
|
||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "input_text_tensor_meta.json"))
|
||||
_EXPECTED_TENSOR_DEFAULT_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "input_text_tensor_default_meta.json"))
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
regex_tokenizer_md = metadata_info.RegexTokenizerMd(
|
||||
self._DELIM_REGEX_PATTERN, self._VOCAB_FILE)
|
||||
|
||||
text_tensor_md = metadata_info.InputTextTensorMd(self._NAME,
|
||||
self._DESCRIPTION,
|
||||
regex_tokenizer_md)
|
||||
|
||||
metadata_json = _metadata.convert_to_json(
|
||||
_create_dummy_model_metadata_with_tensor(
|
||||
text_tensor_md.create_metadata()))
|
||||
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
def test_create_metadata_by_default_should_succeed(self):
|
||||
text_tensor_md = metadata_info.InputTextTensorMd()
|
||||
|
||||
metadata_json = _metadata.convert_to_json(
|
||||
_create_dummy_model_metadata_with_tensor(
|
||||
text_tensor_md.create_metadata()))
|
||||
with open(self._EXPECTED_TENSOR_DEFAULT_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
class ClassificationTensorMdTest(parameterized.TestCase):
|
||||
|
||||
_NAME = "probability"
|
||||
|
|
|
@ -113,7 +113,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
def test_initialize_and_populate(self):
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
self.image_classifier_model_buffer)
|
||||
writer.add_genernal_info(
|
||||
writer.add_general_info(
|
||||
model_name='my_image_model', model_description='my_description')
|
||||
tflite_model, metadata_json = writer.populate()
|
||||
self.assertLen(tflite_model, 1882986)
|
||||
|
@ -142,7 +142,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
def test_add_feature_input_output(self):
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
self.image_classifier_model_buffer)
|
||||
writer.add_genernal_info(
|
||||
writer.add_general_info(
|
||||
model_name='my_model', model_description='my_description')
|
||||
writer.add_feature_input(
|
||||
name='input_tesnor', description='a feature input tensor')
|
||||
|
@ -191,7 +191,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
def test_image_classifier(self):
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
self.image_classifier_model_buffer)
|
||||
writer.add_genernal_info(
|
||||
writer.add_general_info(
|
||||
model_name='image_classifier',
|
||||
model_description='Imagenet classification model')
|
||||
writer.add_image_input(
|
||||
|
@ -282,7 +282,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
|
||||
def test_image_classifier_with_locale_and_score_calibration(self):
|
||||
writer = metadata_writer.MetadataWriter(self.image_classifier_model_buffer)
|
||||
writer.add_genernal_info(
|
||||
writer.add_general_info(
|
||||
model_name='image_classifier',
|
||||
model_description='Classify the input image.')
|
||||
writer.add_image_input(
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata_writer.text_classifier."""
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
|
||||
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
|
||||
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
|
||||
"movie_review_labels.txt")
|
||||
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
|
||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||
_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
|
||||
|
||||
|
||||
class TextClassifierTest(absltest.TestCase):
|
||||
|
||||
def test_write_metadata(self,):
|
||||
with open(_MODEL, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = text_classifier.MetadataWriter.create_for_regex_model(
|
||||
model_buffer,
|
||||
regex_tokenizer=metadata_writer.RegexTokenizer(
|
||||
delim_regex_pattern=_DELIM_REGEX_PATTERN,
|
||||
vocab_file_path=_VOCAB_FILE),
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
|
||||
_, metadata_json = writer.populate()
|
||||
|
||||
with open(_JSON_FILE, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
|
@ -154,12 +154,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
self.assertIsInstance(classifier, _TextClassifier)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"ExternalFile must specify at least one of 'file_content', "
|
||||
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||
base_options = _BaseOptions(model_asset_path='')
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
options = _TextClassifierOptions(base_options=base_options)
|
||||
_TextClassifier.create_from_options(options)
|
||||
|
||||
|
|
|
@ -147,12 +147,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
self.assertIsInstance(classifier, _ImageClassifier)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"ExternalFile must specify at least one of 'file_content', "
|
||||
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||
base_options = _BaseOptions(model_asset_path='')
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
options = _ImageClassifierOptions(base_options=base_options)
|
||||
_ImageClassifier.create_from_options(options)
|
||||
|
||||
|
|
|
@ -97,12 +97,10 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"ExternalFile must specify at least one of 'file_content', "
|
||||
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||
base_options = _BaseOptions(model_asset_path='')
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
options = _ImageSegmenterOptions(base_options=base_options)
|
||||
_ImageSegmenter.create_from_options(options)
|
||||
|
||||
|
|
|
@ -119,12 +119,10 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
self.assertIsInstance(detector, _ObjectDetector)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"ExternalFile must specify at least one of 'file_content', "
|
||||
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||
base_options = _BaseOptions(model_asset_path='')
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
options = _ObjectDetectorOptions(base_options=base_options)
|
||||
_ObjectDetector.create_from_options(options)
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ py_library(
|
|||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
|
|
|
@ -22,7 +22,7 @@ from mediapipe.python import packet_getter
|
|||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet
|
||||
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
|
||||
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
|
||||
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
|||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
|
||||
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions
|
||||
_ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
|
||||
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
|
@ -40,7 +40,7 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
|
|||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
||||
|
@ -81,13 +81,13 @@ class ImageSegmenterOptions:
|
|||
[List[image_module.Image], image_module.Image, int], None]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ImageSegmenterOptionsProto:
|
||||
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
||||
"""Generates an ImageSegmenterOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||
segmenter_options_proto = _SegmenterOptionsProto(
|
||||
output_type=self.output_type.value, activation=self.activation.value)
|
||||
return _ImageSegmenterOptionsProto(
|
||||
return _ImageSegmenterGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
segmenter_options=segmenter_options_proto)
|
||||
|
||||
|
|
12
mediapipe/tasks/testdata/metadata/BUILD
vendored
12
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -31,6 +31,7 @@ mediapipe_files(srcs = [
|
|||
"mobilenet_v2_1.0_224_quant.tflite",
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||
"movie_review.tflite",
|
||||
])
|
||||
|
||||
exports_files([
|
||||
|
@ -54,6 +55,11 @@ exports_files([
|
|||
"labels.txt",
|
||||
"mobilenet_v2_1.0_224.json",
|
||||
"mobilenet_v2_1.0_224_quant.json",
|
||||
"input_text_tensor_meta.json",
|
||||
"input_text_tensor_default_meta.json",
|
||||
"movie_review_labels.txt",
|
||||
"regex_vocab.txt",
|
||||
"movie_review.json",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
|
@ -67,6 +73,7 @@ filegroup(
|
|||
"mobilenet_v2_1.0_224_quant.tflite",
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||
"movie_review.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -86,9 +93,14 @@ filegroup(
|
|||
"input_image_tensor_float_meta.json",
|
||||
"input_image_tensor_uint8_meta.json",
|
||||
"input_image_tensor_unsupported_meta.json",
|
||||
"input_text_tensor_default_meta.json",
|
||||
"input_text_tensor_meta.json",
|
||||
"labels.txt",
|
||||
"mobilenet_v2_1.0_224.json",
|
||||
"mobilenet_v2_1.0_224_quant.json",
|
||||
"movie_review.json",
|
||||
"movie_review_labels.txt",
|
||||
"regex_vocab.txt",
|
||||
"score_calibration.txt",
|
||||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
|
|
17
mediapipe/tasks/testdata/metadata/input_text_tensor_default_meta.json
vendored
Normal file
17
mediapipe/tasks/testdata/metadata/input_text_tensor_default_meta.json
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
34
mediapipe/tasks/testdata/metadata/input_text_tensor_meta.json
vendored
Normal file
34
mediapipe/tasks/testdata/metadata/input_text_tensor_meta.json
vendored
Normal file
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "input text",
|
||||
"description": "The input string.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "RegexTokenizerOptions",
|
||||
"options": {
|
||||
"delim_regex_pattern": "[^\\w\\']+",
|
||||
"vocab_file": [
|
||||
{
|
||||
"name": "vocab.txt",
|
||||
"description": "Vocabulary file to convert natural language words to embedding vectors.",
|
||||
"type": "VOCABULARY"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
63
mediapipe/tasks/testdata/metadata/movie_review.json
vendored
Normal file
63
mediapipe/tasks/testdata/metadata/movie_review.json
vendored
Normal file
|
@ -0,0 +1,63 @@
|
|||
{
|
||||
"name": "TextClassifier",
|
||||
"description": "Classify the input text into a set of known categories.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Embedding vectors representing the input text to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "RegexTokenizerOptions",
|
||||
"options": {
|
||||
"delim_regex_pattern": "[^\\w\\']+",
|
||||
"vocab_file": [
|
||||
{
|
||||
"name": "regex_vocab.txt",
|
||||
"description": "Vocabulary file to convert natural language words to embedding vectors.",
|
||||
"type": "VOCABULARY"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.2.1"
|
||||
}
|
2
mediapipe/tasks/testdata/metadata/movie_review_labels.txt
vendored
Normal file
2
mediapipe/tasks/testdata/metadata/movie_review_labels.txt
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
Negative
|
||||
Positive
|
10000
mediapipe/tasks/testdata/metadata/regex_vocab.txt
vendored
Normal file
10000
mediapipe/tasks/testdata/metadata/regex_vocab.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
6
mediapipe/tasks/testdata/text/BUILD
vendored
6
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -28,6 +28,7 @@ mediapipe_files(srcs = [
|
|||
"bert_text_classifier.tflite",
|
||||
"mobilebert_embedding_with_metadata.tflite",
|
||||
"mobilebert_with_metadata.tflite",
|
||||
"regex_one_embedding_with_metadata.tflite",
|
||||
"test_model_text_classifier_bool_output.tflite",
|
||||
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite",
|
||||
|
@ -92,6 +93,11 @@ filegroup(
|
|||
srcs = ["mobilebert_embedding_with_metadata.tflite"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "regex_embedding_with_metadata",
|
||||
srcs = ["regex_one_embedding_with_metadata.tflite"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "universal_sentence_encoder_qa",
|
||||
data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
|
||||
|
|
5
mediapipe/tasks/testdata/vision/BUILD
vendored
5
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -144,8 +144,13 @@ filegroup(
|
|||
)
|
||||
|
||||
# Gestures related models. Visible to model_maker.
|
||||
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
|
||||
filegroup(
|
||||
name = "test_gesture_models",
|
||||
srcs = [
|
||||
"hand_landmark_full.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
],
|
||||
visibility = [
|
||||
"//mediapipe/model_maker:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
|
|
106
mediapipe/tasks/web/BUILD
Normal file
106
mediapipe/tasks/web/BUILD
Normal file
|
@ -0,0 +1,106 @@
|
|||
# This contains the MediaPipe Tasks NPM package definitions.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||
load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
|
||||
load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
# Audio
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "audio_lib",
|
||||
srcs = ["audio.ts"],
|
||||
deps = ["//mediapipe/tasks/web/audio:audio_lib"],
|
||||
)
|
||||
|
||||
rollup_bundle(
|
||||
name = "audio_bundle",
|
||||
config_file = "rollup.config.mjs",
|
||||
entry_point = "audio.ts",
|
||||
output_dir = False,
|
||||
deps = [
|
||||
":audio_lib",
|
||||
"@npm//@rollup/plugin-commonjs",
|
||||
"@npm//@rollup/plugin-node-resolve",
|
||||
],
|
||||
)
|
||||
|
||||
pkg_npm(
|
||||
name = "audio_pkg",
|
||||
package_name = "__PACKAGE_NAME__",
|
||||
srcs = ["package.json"],
|
||||
substitutions = {
|
||||
"__PACKAGE_NAME__": "@mediapipe/tasks-audio",
|
||||
"__DESCRIPTION__": "MediaPipe Audio Tasks",
|
||||
"__BUNDLE__": "audio_bundle.js",
|
||||
},
|
||||
tgz = "audio.tgz",
|
||||
deps = [":audio_bundle"],
|
||||
)
|
||||
|
||||
# Text
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "text_lib",
|
||||
srcs = ["text.ts"],
|
||||
deps = ["//mediapipe/tasks/web/text:text_lib"],
|
||||
)
|
||||
|
||||
rollup_bundle(
|
||||
name = "text_bundle",
|
||||
config_file = "rollup.config.mjs",
|
||||
entry_point = "text.ts",
|
||||
output_dir = False,
|
||||
deps = [
|
||||
":text_lib",
|
||||
"@npm//@rollup/plugin-commonjs",
|
||||
"@npm//@rollup/plugin-node-resolve",
|
||||
],
|
||||
)
|
||||
|
||||
pkg_npm(
|
||||
name = "text_pkg",
|
||||
package_name = "__PACKAGE_NAME__",
|
||||
srcs = ["package.json"],
|
||||
substitutions = {
|
||||
"__PACKAGE_NAME__": "@mediapipe/tasks-text",
|
||||
"__DESCRIPTION__": "MediaPipe Text Tasks",
|
||||
"__BUNDLE__": "text_bundle.js",
|
||||
},
|
||||
tgz = "text.tgz",
|
||||
deps = [":text_bundle"],
|
||||
)
|
||||
|
||||
# Vision
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "vision_lib",
|
||||
srcs = ["vision.ts"],
|
||||
deps = ["//mediapipe/tasks/web/vision:vision_lib"],
|
||||
)
|
||||
|
||||
rollup_bundle(
|
||||
name = "vision_bundle",
|
||||
config_file = "rollup.config.mjs",
|
||||
entry_point = "vision.ts",
|
||||
output_dir = False,
|
||||
deps = [
|
||||
":vision_lib",
|
||||
"@npm//@rollup/plugin-commonjs",
|
||||
"@npm//@rollup/plugin-node-resolve",
|
||||
],
|
||||
)
|
||||
|
||||
pkg_npm(
|
||||
name = "vision_pkg",
|
||||
package_name = "__PACKAGE_NAME__",
|
||||
srcs = ["package.json"],
|
||||
substitutions = {
|
||||
"__PACKAGE_NAME__": "@mediapipe/tasks-vision",
|
||||
"__DESCRIPTION__": "MediaPipe Vision Tasks",
|
||||
"__BUNDLE__": "vision_bundle.js",
|
||||
},
|
||||
tgz = "vision.tgz",
|
||||
deps = [":vision_bundle"],
|
||||
)
|
17
mediapipe/tasks/web/audio.ts
Normal file
17
mediapipe/tasks/web/audio.ts
Normal file
|
@ -0,0 +1,17 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
export * from '../../tasks/web/audio/index';
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "audio_lib",
|
||||
srcs = ["index.ts"],
|
||||
|
|
|
@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
|
|||
*/
|
||||
async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto =
|
||||
await convertBaseOptionsToProto(options.baseOptions);
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
|
@ -198,7 +198,7 @@ export class AudioClassifier extends TaskRunner {
|
|||
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
|
||||
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
|
||||
classifierNode.addOutputStream(
|
||||
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
||||
'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
|
||||
classifierNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(classifierNode);
|
||||
|
|
|
@ -26,6 +26,8 @@ mediapipe_ts_library(
|
|||
name = "base_options",
|
||||
srcs = ["base_options.ts"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
|
||||
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||
|
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
|||
* Converts a BaseOptions API object to its Protobuf representation.
|
||||
* @throws If neither a model assset path or buffer is provided
|
||||
*/
|
||||
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
|
||||
Promise<BaseOptionsProto> {
|
||||
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
|
||||
export async function convertBaseOptionsToProto(
|
||||
updatedOptions: BaseOptions,
|
||||
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
|
||||
const result =
|
||||
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
|
||||
|
||||
await configureExternalFile(updatedOptions, result);
|
||||
configureAcceleration(updatedOptions, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configues the `externalFile` option and validates that a single model is
|
||||
* provided.
|
||||
*/
|
||||
async function configureExternalFile(
|
||||
options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const externalFile = proto.getModelAsset() || new ExternalFile();
|
||||
proto.setModelAsset(externalFile);
|
||||
|
||||
if (options.modelAssetPath || options.modelAssetBuffer) {
|
||||
if (options.modelAssetPath && options.modelAssetBuffer) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
}
|
||||
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
|
||||
|
||||
let modelAssetBuffer = options.modelAssetBuffer;
|
||||
if (!modelAssetBuffer) {
|
||||
const response = await fetch(options.modelAssetPath!.toString());
|
||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||
}
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
}
|
||||
|
||||
if (!externalFile.hasFileContent()) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
|
||||
let modelAssetBuffer = baseOptions.modelAssetBuffer;
|
||||
if (!modelAssetBuffer) {
|
||||
const response = await fetch(baseOptions.modelAssetPath!.toString());
|
||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||
}
|
||||
|
||||
const proto = new BaseOptionsProto();
|
||||
const externalFile = new ExternalFile();
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
proto.setModelAsset(externalFile);
|
||||
return proto;
|
||||
}
|
||||
|
||||
/** Configues the `acceleration` option. */
|
||||
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
||||
if ('delegate' in options) {
|
||||
const acceleration = new Acceleration();
|
||||
if (options.delegate === 'cpu') {
|
||||
acceleration.setXnnpack(
|
||||
new InferenceCalculatorOptions.Delegate.Xnnpack());
|
||||
proto.setAcceleration(acceleration);
|
||||
} else if (options.delegate === 'gpu') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
proto.setAcceleration(acceleration);
|
||||
} else {
|
||||
proto.clearAcceleration();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
8
mediapipe/tasks/web/core/base_options.d.ts
vendored
8
mediapipe/tasks/web/core/base_options.d.ts
vendored
|
@ -22,10 +22,14 @@ export interface BaseOptions {
|
|||
* The model path to the model asset file. Only one of `modelAssetPath` or
|
||||
* `modelAssetBuffer` can be set.
|
||||
*/
|
||||
modelAssetPath?: string;
|
||||
modelAssetPath?: string|undefined;
|
||||
|
||||
/**
|
||||
* A buffer containing the model aaset. Only one of `modelAssetPath` or
|
||||
* `modelAssetBuffer` can be set.
|
||||
*/
|
||||
modelAssetBuffer?: Uint8Array;
|
||||
modelAssetBuffer?: Uint8Array|undefined;
|
||||
|
||||
/** Overrides the default backend to use for the provided model. */
|
||||
delegate?: 'cpu'|'gpu'|undefined;
|
||||
}
|
||||
|
|
15
mediapipe/tasks/web/package.json
Normal file
15
mediapipe/tasks/web/package.json
Normal file
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"name": "__PACKAGE_NAME__",
|
||||
"version": "__VERSION__",
|
||||
"description": "__DESCRIPTION__",
|
||||
"main": "__BUNDLE__",
|
||||
"module": "__BUNDLE__",
|
||||
"author": "mediapipe@google.com",
|
||||
"license": "Apache-2.0",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"google-protobuf": "^3.21.2"
|
||||
},
|
||||
"homepage": "http://mediapipe.dev",
|
||||
"keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ]
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user