From cbc7eb661b4a3c321eb9a949c69b6c9cd52498f0 Mon Sep 17 00:00:00 2001 From: Yuqi Li Date: Mon, 10 Oct 2022 12:12:13 -0700 Subject: [PATCH 01/18] Add metadata_info.py for metadata writer. PiperOrigin-RevId: 480146881 --- .../python/metadata/metadata_writers/BUILD | 21 + .../metadata/metadata_writers/__init__.py | 13 + .../metadata_writers/metadata_info.py | 446 +++++++++++++++ .../test/metadata/metadata_writers/BUILD | 26 + .../metadata_writers/metadata_info_test.py | 343 ++++++++++++ mediapipe/tasks/python/test/test_utils.py | 10 + mediapipe/tasks/testdata/metadata/BUILD | 28 + .../metadata/associated_file_meta.json | 10 + .../metadata/bounding_box_tensor_meta.json | 35 ++ .../classification_tensor_float_meta.json | 52 ++ .../classification_tensor_uint8_meta.json | 52 ++ ...lassification_tensor_unsupported_meta.json | 46 ++ .../metadata/feature_tensor_meta.json | 35 ++ .../tasks/testdata/metadata/general_meta.json | 7 + .../testdata/metadata/image_tensor_meta.json | 35 ++ .../input_image_tensor_float_meta.json | 47 ++ .../input_image_tensor_uint8_meta.json | 43 ++ .../input_image_tensor_unsupported_meta.json | 37 ++ .../testdata/metadata/score_calibration.txt | 511 ++++++++++++++++++ .../metadata/score_calibration_file_meta.json | 9 + .../score_calibration_tensor_meta.json | 15 + third_party/external_files.bzl | 84 +++ 22 files changed, 1905 insertions(+) create mode 100644 mediapipe/tasks/python/metadata/metadata_writers/BUILD create mode 100644 mediapipe/tasks/python/metadata/metadata_writers/__init__.py create mode 100644 mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py create mode 100644 mediapipe/tasks/python/test/metadata/metadata_writers/BUILD create mode 100644 mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py create mode 100644 mediapipe/tasks/testdata/metadata/associated_file_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/bounding_box_tensor_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/classification_tensor_float_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/classification_tensor_uint8_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/classification_tensor_unsupported_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/feature_tensor_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/general_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/image_tensor_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/input_image_tensor_float_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/input_image_tensor_uint8_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/input_image_tensor_unsupported_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/score_calibration.txt create mode 100644 mediapipe/tasks/testdata/metadata/score_calibration_file_meta.json create mode 100644 mediapipe/tasks/testdata/metadata/score_calibration_tensor_meta.json diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD new file mode 100644 index 000000000..cc8bd45db --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -0,0 +1,21 @@ +# Placeholder for internal Python strict library compatibility macro. + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +py_library( + name = "metadata_info", + srcs = [ + "metadata_info.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/metadata:schema_py", + ], +) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/__init__.py b/mediapipe/tasks/python/metadata/metadata_writers/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py new file mode 100644 index 000000000..07938d863 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py @@ -0,0 +1,446 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper classes for common model metadata information.""" + +import csv +import os +from typing import List, Optional, Type + +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.metadata import schema_py_generated as _schema_fb + +# Min and max values for UINT8 tensors. +_MIN_UINT8 = 0 +_MAX_UINT8 = 255 + +# Default description for vocabulary files. +_VOCAB_FILE_DESCRIPTION = ("Vocabulary file to convert natural language " + "words to embedding vectors.") + + +class GeneralMd: + """A container for common metadata information of a model. + + Attributes: + name: name of the model. + version: version of the model. + description: description of what the model does. + author: author of the model. + licenses: licenses of the model. + """ + + def __init__(self, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + author: Optional[str] = None, + licenses: Optional[str] = None) -> None: + self.name = name + self.version = version + self.description = description + self.author = author + self.licenses = licenses + + def create_metadata(self) -> _metadata_fb.ModelMetadataT: + """Creates the model metadata based on the general model information. + + Returns: + A Flatbuffers Python object of the model metadata. + """ + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.name = self.name + model_metadata.version = self.version + model_metadata.description = self.description + model_metadata.author = self.author + model_metadata.license = self.licenses + return model_metadata + + +class AssociatedFileMd: + """A container for common associated file metadata information. + + Attributes: + file_path: path to the associated file. + description: description of the associated file. + file_type: file type of the associated file [1]. + locale: locale of the associated file [2]. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L77 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L176 + """ + + def __init__( + self, + file_path: str, + description: Optional[str] = None, + file_type: Optional[int] = _metadata_fb.AssociatedFileType.UNKNOWN, + locale: Optional[str] = None) -> None: + self.file_path = file_path + self.description = description + self.file_type = file_type + self.locale = locale + + def create_metadata(self) -> _metadata_fb.AssociatedFileT: + """Creates the associated file metadata. + + Returns: + A Flatbuffers Python object of the associated file metadata. + """ + file_metadata = _metadata_fb.AssociatedFileT() + file_metadata.name = os.path.basename(self.file_path) + file_metadata.description = self.description + file_metadata.type = self.file_type + file_metadata.locale = self.locale + return file_metadata + + +class LabelFileMd(AssociatedFileMd): + """A container for label file metadata information.""" + + _LABEL_FILE_DESCRIPTION = ("Labels for categories that the model can " + "recognize.") + _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS + + def __init__(self, file_path: str, locale: Optional[str] = None) -> None: + """Creates a LabelFileMd object. + + Args: + file_path: file_path of the label file. + locale: locale of the label file [1]. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L176 + """ + super().__init__(file_path, self._LABEL_FILE_DESCRIPTION, self._FILE_TYPE, + locale) + + +class ScoreCalibrationMd: + """A container for score calibration [1] metadata information. + + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + """ + + _SCORE_CALIBRATION_FILE_DESCRIPTION = ( + "Contains sigmoid-based score calibration parameters. The main purposes " + "of score calibration is to make scores across classes comparable, so " + "that a common threshold can be used for all output classes.") + _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_SCORE_CALIBRATION + + def __init__(self, + score_transformation_type: _metadata_fb.ScoreTransformationType, + default_score: float, file_path: str) -> None: + """Creates a ScoreCalibrationMd object. + + Args: + score_transformation_type: type of the function used for transforming the + uncalibrated score before applying score calibration. + default_score: the default calibrated score to apply if the uncalibrated + score is below min_score or if no parameters were specified for a given + index. + file_path: file_path of the score calibration file [1]. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L133 + + Raises: + ValueError: if the score_calibration file is malformed. + """ + self._score_transformation_type = score_transformation_type + self._default_score = default_score + self._file_path = file_path + + # Sanity check the score calibration file. + with open(self._file_path) as calibration_file: + csv_reader = csv.reader(calibration_file, delimiter=",") + for row in csv_reader: + if row and len(row) != 3 and len(row) != 4: + raise ValueError( + f"Expected empty lines or 3 or 4 parameters per line in score" + f" calibration file, but got {len(row)}.") + + if row and float(row[0]) < 0: + raise ValueError( + f"Expected scale to be a non-negative value, but got " + f"{float(row[0])}.") + + def create_metadata(self) -> _metadata_fb.ProcessUnitT: + """Creates the score calibration metadata based on the information. + + Returns: + A Flatbuffers Python object of the score calibration metadata. + """ + score_calibration = _metadata_fb.ProcessUnitT() + score_calibration.optionsType = ( + _metadata_fb.ProcessUnitOptions.ScoreCalibrationOptions) + options = _metadata_fb.ScoreCalibrationOptionsT() + options.scoreTransformation = self._score_transformation_type + options.defaultScore = self._default_score + score_calibration.options = options + return score_calibration + + def create_score_calibration_file_md(self) -> AssociatedFileMd: + return AssociatedFileMd(self._file_path, + self._SCORE_CALIBRATION_FILE_DESCRIPTION, + self._FILE_TYPE) + + +class TensorMd: + """A container for common tensor metadata information. + + Attributes: + name: name of the tensor. + description: description of what the tensor is. + min_values: per-channel minimum value of the tensor. + max_values: per-channel maximum value of the tensor. + content_type: content_type of the tensor. + associated_files: information of the associated files in the tensor. + tensor_name: name of the corresponding tensor [1] in the TFLite model. It is + used to locate the corresponding tensor and decide the order of the tensor + metadata [2] when populating model metadata. + [1]: + https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 + """ + + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + min_values: Optional[List[float]] = None, + max_values: Optional[List[float]] = None, + content_type: int = _metadata_fb.ContentProperties.FeatureProperties, + associated_files: Optional[List[Type[AssociatedFileMd]]] = None, + tensor_name: Optional[str] = None) -> None: + self.name = name + self.description = description + self.min_values = min_values + self.max_values = max_values + self.content_type = content_type + self.associated_files = associated_files + self.tensor_name = tensor_name + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the input tensor metadata based on the information. + + Returns: + A Flatbuffers Python object of the input metadata. + """ + tensor_metadata = _metadata_fb.TensorMetadataT() + tensor_metadata.name = self.name + tensor_metadata.description = self.description + + # Create min and max values + stats = _metadata_fb.StatsT() + stats.max = self.max_values + stats.min = self.min_values + tensor_metadata.stats = stats + + # Create content properties + content = _metadata_fb.ContentT() + if self.content_type is _metadata_fb.ContentProperties.FeatureProperties: + content.contentProperties = _metadata_fb.FeaturePropertiesT() + elif self.content_type is _metadata_fb.ContentProperties.ImageProperties: + content.contentProperties = _metadata_fb.ImagePropertiesT() + elif self.content_type is ( + _metadata_fb.ContentProperties.BoundingBoxProperties): + content.contentProperties = _metadata_fb.BoundingBoxPropertiesT() + elif self.content_type is _metadata_fb.ContentProperties.AudioProperties: + content.contentProperties = _metadata_fb.AudioPropertiesT() + + content.contentPropertiesType = self.content_type + tensor_metadata.content = content + + # TODO: check if multiple label files have populated locale. + # Create associated files + if self.associated_files: + tensor_metadata.associatedFiles = [ + file.create_metadata() for file in self.associated_files + ] + return tensor_metadata + + +class InputImageTensorMd(TensorMd): + """A container for input image tensor metadata information. + + Attributes: + norm_mean: the mean value used in tensor normalization [1]. + norm_std: the std value used in the tensor normalization [1]. norm_mean and + norm_std must have the same dimension. + color_space_type: the color space type of the input image [2]. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L198 + """ + + # Min and max float values for image pixels. + _MIN_PIXEL = 0.0 + _MAX_PIXEL = 255.0 + + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + norm_mean: Optional[List[float]] = None, + norm_std: Optional[List[float]] = None, + color_space_type: Optional[int] = _metadata_fb.ColorSpaceType.UNKNOWN, + tensor_type: Optional["_schema_fb.TensorType"] = None) -> None: + """Initializes the instance of InputImageTensorMd. + + Args: + name: name of the tensor. + description: description of what the tensor is. + norm_mean: the mean value used in tensor normalization [1]. + norm_std: the std value used in the tensor normalization [1]. norm_mean + and norm_std must have the same dimension. + color_space_type: the color space type of the input image [2]. + tensor_type: data type of the tensor. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L198 + + Raises: + ValueError: if norm_mean and norm_std have different dimensions. + """ + if norm_std and norm_mean and len(norm_std) != len(norm_mean): + raise ValueError( + f"norm_mean and norm_std are expected to be the same dim. But got " + f"{len(norm_mean)} and {len(norm_std)}") + + if tensor_type is _schema_fb.TensorType.UINT8: + min_values = [_MIN_UINT8] + max_values = [_MAX_UINT8] + elif tensor_type is _schema_fb.TensorType.FLOAT32 and norm_std and norm_mean: + min_values = [ + float(self._MIN_PIXEL - mean) / std + for mean, std in zip(norm_mean, norm_std) + ] + max_values = [ + float(self._MAX_PIXEL - mean) / std + for mean, std in zip(norm_mean, norm_std) + ] + else: + # Uint8 and Float32 are the two major types currently. And Task library + # doesn't support other types so far. + min_values = None + max_values = None + + super().__init__(name, description, min_values, max_values, + _metadata_fb.ContentProperties.ImageProperties) + self.norm_mean = norm_mean + self.norm_std = norm_std + self.color_space_type = color_space_type + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the input image metadata based on the information. + + Returns: + A Flatbuffers Python object of the input image metadata. + """ + tensor_metadata = super().create_metadata() + tensor_metadata.content.contentProperties.colorSpace = self.color_space_type + # Create normalization parameters + if self.norm_mean and self.norm_std: + normalization = _metadata_fb.ProcessUnitT() + normalization.optionsType = ( + _metadata_fb.ProcessUnitOptions.NormalizationOptions) + normalization.options = _metadata_fb.NormalizationOptionsT() + normalization.options.mean = self.norm_mean + normalization.options.std = self.norm_std + tensor_metadata.processUnits = [normalization] + return tensor_metadata + + +class ClassificationTensorMd(TensorMd): + """A container for the classification tensor metadata information. + + Attributes: + label_files: information of the label files [1] in the classification + tensor. + score_calibration_md: information of the score calibration operation [2] in + the classification tensor. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + """ + + # Min and max float values for classification results. + _MIN_FLOAT = 0.0 + _MAX_FLOAT = 1.0 + + def __init__(self, + name: Optional[str] = None, + description: Optional[str] = None, + label_files: Optional[List[LabelFileMd]] = None, + tensor_type: Optional[int] = None, + score_calibration_md: Optional[ScoreCalibrationMd] = None, + tensor_name: Optional[str] = None) -> None: + """Initializes the instance of ClassificationTensorMd. + + Args: + name: name of the tensor. + description: description of what the tensor is. + label_files: information of the label files [1] in the classification + tensor. + tensor_type: data type of the tensor. + score_calibration_md: information of the score calibration files operation + [2] in the classification tensor. + tensor_name: name of the corresponding tensor [3] in the TFLite model. It + is used to locate the corresponding classification tensor and decide the + order of the tensor metadata [4] when populating model metadata. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + [3]: + https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 + [4]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 + """ + self.score_calibration_md = score_calibration_md + + if tensor_type is _schema_fb.TensorType.UINT8: + min_values = [_MIN_UINT8] + max_values = [_MAX_UINT8] + elif tensor_type is _schema_fb.TensorType.FLOAT32: + min_values = [self._MIN_FLOAT] + max_values = [self._MAX_FLOAT] + else: + # Uint8 and Float32 are the two major types currently. And Task library + # doesn't support other types so far. + min_values = None + max_values = None + + associated_files = label_files or [] + if self.score_calibration_md: + associated_files.append( + score_calibration_md.create_score_calibration_file_md()) + + super().__init__(name, description, min_values, max_values, + _metadata_fb.ContentProperties.FeatureProperties, + associated_files, tensor_name) + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the classification tensor metadata based on the information.""" + tensor_metadata = super().create_metadata() + if self.score_calibration_md: + tensor_metadata.processUnits = [ + self.score_calibration_md.create_metadata() + ] + return tensor_metadata diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD new file mode 100644 index 000000000..debb25787 --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD @@ -0,0 +1,26 @@ +# Placeholder for internal Python strict test compatibility macro. + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_test( + name = "metadata_info_test", + srcs = ["metadata_info_test.py"], + data = [ + "//mediapipe/tasks/testdata/metadata:data_files", + ], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/metadata:schema_py", + "//mediapipe/tasks/python/metadata", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info", + "//mediapipe/tasks/python/test:test_utils", + "@flatbuffers//:runtime_py", + ], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py new file mode 100644 index 000000000..75602c83c --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py @@ -0,0 +1,343 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for metadata info classes.""" + +import tempfile + +from absl.testing import absltest +from absl.testing import parameterized + +import flatbuffers +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.metadata import schema_py_generated as _schema_fb +from mediapipe.tasks.python.metadata import metadata as _metadata +from mediapipe.tasks.python.metadata.metadata_writers import metadata_info +from mediapipe.tasks.python.test import test_utils + +_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt") + + +class GeneralMdTest(absltest.TestCase): + + _EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path( + "general_meta.json") + + def test_create_metadata_should_succeed(self): + general_md = metadata_info.GeneralMd( + name="model", + version="v1", + description="A ML model.", + author="MediaPipe", + licenses="Apache") + general_metadata = general_md.create_metadata() + + # Create the Flatbuffers object and convert it to the json format. + builder = flatbuffers.Builder(0) + builder.Finish( + general_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_json = _metadata.convert_to_json(bytes(builder.Output())) + + with open(self._EXPECTED_GENERAL_META_JSON, "r") as f: + expected_json = f.read() + + self.assertEqual(metadata_json, expected_json) + + +class AssociatedFileMdTest(absltest.TestCase): + + _EXPECTED_META_JSON = test_utils.get_test_data_path( + "associated_file_meta.json") + + def test_create_metadata_should_succeed(self): + file_md = metadata_info.AssociatedFileMd( + file_path="label.txt", + description="The label file.", + file_type=_metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS, + locale="en") + file_metadata = file_md.create_metadata() + + # Create the Flatbuffers object and convert it to the json format. + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.associatedFiles = [file_metadata] + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_json = _metadata.convert_to_json(bytes(builder.Output())) + + with open(self._EXPECTED_META_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + +class TensorMdTest(parameterized.TestCase): + + _TENSOR_NAME = "input" + _TENSOR_DESCRIPTION = "The input tensor." + _TENSOR_MIN = 0 + _TENSOR_MAX = 1 + _LABEL_FILE_EN = "labels.txt" + _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. + _EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path( + "feature_tensor_meta.json") + _EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path( + "image_tensor_meta.json") + _EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path( + "bounding_box_tensor_meta.json") + + @parameterized.named_parameters( + { + "testcase_name": "feature_tensor", + "content_type": _metadata_fb.ContentProperties.FeatureProperties, + "golden_json": _EXPECTED_FEATURE_TENSOR_JSON + }, { + "testcase_name": "image_tensor", + "content_type": _metadata_fb.ContentProperties.ImageProperties, + "golden_json": _EXPECTED_IMAGE_TENSOR_JSON + }, { + "testcase_name": "bounding_box_tensor", + "content_type": _metadata_fb.ContentProperties.BoundingBoxProperties, + "golden_json": _EXPECTED_BOUNDING_BOX_TENSOR_JSON + }) + def test_create_metadata_should_succeed(self, content_type, golden_json): + associated_file1 = metadata_info.AssociatedFileMd( + file_path=self._LABEL_FILE_EN, locale="en") + associated_file2 = metadata_info.AssociatedFileMd( + file_path=self._LABEL_FILE_CN, locale="cn") + + tensor_md = metadata_info.TensorMd( + name=self._TENSOR_NAME, + description=self._TENSOR_DESCRIPTION, + min_values=[self._TENSOR_MIN], + max_values=[self._TENSOR_MAX], + content_type=content_type, + associated_files=[associated_file1, associated_file2]) + tensor_metadata = tensor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + with open(golden_json, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + +class InputImageTensorMdTest(parameterized.TestCase): + + _NAME = "image" + _DESCRIPTION = "The input image." + _NORM_MEAN = (0, 127.5, 255) + _NORM_STD = (127.5, 127.5, 127.5) + _COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB + _EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path( + "input_image_tensor_float_meta.json") + _EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path( + "input_image_tensor_uint8_meta.json") + _EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path( + "input_image_tensor_unsupported_meta.json") + + @parameterized.named_parameters( + { + "testcase_name": "float", + "tensor_type": _schema_fb.TensorType.FLOAT32, + "golden_json": _EXPECTED_FLOAT_TENSOR_JSON + }, { + "testcase_name": "uint8", + "tensor_type": _schema_fb.TensorType.UINT8, + "golden_json": _EXPECTED_UINT8_TENSOR_JSON + }, { + "testcase_name": "unsupported_tensor_type", + "tensor_type": _schema_fb.TensorType.INT16, + "golden_json": _EXPECTED_UNSUPPORTED_TENSOR_JSON + }) + def test_create_metadata_should_succeed(self, tensor_type, golden_json): + tesnor_md = metadata_info.InputImageTensorMd( + name=self._NAME, + description=self._DESCRIPTION, + norm_mean=list(self._NORM_MEAN), + norm_std=list(self._NORM_STD), + color_space_type=self._COLOR_SPACE_TYPE, + tensor_type=tensor_type) + tensor_metadata = tesnor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + with open(golden_json, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + def test_init_should_throw_exception_with_incompatible_mean_and_std(self): + norm_mean = [0] + norm_std = [1, 2] + with self.assertRaises(ValueError) as error: + metadata_info.InputImageTensorMd(norm_mean=norm_mean, norm_std=norm_std) + self.assertEqual( + f"norm_mean and norm_std are expected to be the same dim. But got " + f"{len(norm_mean)} and {len(norm_std)}", str(error.exception)) + + +class ClassificationTensorMdTest(parameterized.TestCase): + + _NAME = "probability" + _DESCRIPTION = "The classification result tensor." + _LABEL_FILE_EN = "labels.txt" + _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. + _CALIBRATION_DEFAULT_SCORE = 0.2 + _EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path( + "classification_tensor_float_meta.json") + _EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path( + "classification_tensor_uint8_meta.json") + _EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path( + "classification_tensor_unsupported_meta.json") + + @parameterized.named_parameters( + { + "testcase_name": "float", + "tensor_type": _schema_fb.TensorType.FLOAT32, + "golden_json": _EXPECTED_FLOAT_TENSOR_JSON + }, { + "testcase_name": "uint8", + "tensor_type": _schema_fb.TensorType.UINT8, + "golden_json": _EXPECTED_UINT8_TENSOR_JSON + }, { + "testcase_name": "unsupported_tensor_type", + "tensor_type": _schema_fb.TensorType.INT16, + "golden_json": _EXPECTED_UNSUPPORTED_TENSOR_JSON + }) + def test_create_metadata_should_succeed(self, tensor_type, golden_json): + label_file_en = metadata_info.LabelFileMd( + file_path=self._LABEL_FILE_EN, locale="en") + label_file_cn = metadata_info.LabelFileMd( + file_path=self._LABEL_FILE_CN, locale="cn") + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.IDENTITY, + self._CALIBRATION_DEFAULT_SCORE, _SCORE_CALIBRATION_FILE) + + tesnor_md = metadata_info.ClassificationTensorMd( + name=self._NAME, + description=self._DESCRIPTION, + label_files=[label_file_en, label_file_cn], + tensor_type=tensor_type, + score_calibration_md=score_calibration_md) + tensor_metadata = tesnor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + with open(golden_json, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + +class ScoreCalibrationMdTest(absltest.TestCase): + _DEFAULT_VALUE = 0.2 + _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path( + "score_calibration_tensor_meta.json") + _EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path( + "score_calibration_file_meta.json") + + def test_create_metadata_should_succeed(self): + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE, + _SCORE_CALIBRATION_FILE) + score_calibration_metadata = score_calibration_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_process_uint( + score_calibration_metadata)) + with open(self._EXPECTED_TENSOR_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + def test_create_score_calibration_file_md_should_succeed(self): + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE, + _SCORE_CALIBRATION_FILE) + score_calibration_file_md = ( + score_calibration_md.create_score_calibration_file_md()) + file_metadata = score_calibration_file_md.create_metadata() + + # Create the Flatbuffers object and convert it to the json format. + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.associatedFiles = [file_metadata] + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_json = _metadata.convert_to_json(bytes(builder.Output())) + + with open(self._EXPECTED_MODEL_META_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + def test_create_score_calibration_file_fails_with_less_colunms(self): + with tempfile.TemporaryDirectory() as temp_dir: + malformed_calibration_file = test_utils.create_calibration_file( + temp_dir, content="1.0,0.2") + + with self.assertRaisesRegex( + ValueError, + "Expected empty lines or 3 or 4 parameters per line in score" + + " calibration file, but got 2."): + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE, + malformed_calibration_file) + + def test_create_score_calibration_file_fails_with_negative_scale(self): + with tempfile.TemporaryDirectory() as temp_dir: + malformed_calibration_file = test_utils.create_calibration_file( + temp_dir, content="-1.0,0.2,0.1") + + with self.assertRaisesRegex( + ValueError, + "Expected scale to be a non-negative value, but got -1.0."): + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE, + malformed_calibration_file) + + +def _create_dummy_model_metadata_with_tensor( + tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: + # Create a dummy model using the tensor metadata. + subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata.inputTensorMetadata = [tensor_metadata] + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.subgraphMetadata = [subgraph_metadata] + + # Create the Flatbuffers object and convert it to the json format. + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return bytes(builder.Output()) + + +def _create_dummy_model_metadata_with_process_uint( + process_unit_metadata: _metadata_fb.ProcessUnitT) -> bytes: + # Create a dummy model using the tensor metadata. + subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata.inputProcessUnits = [process_unit_metadata] + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.subgraphMetadata = [subgraph_metadata] + + # Create the Flatbuffers object and convert it to the json format. + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return bytes(builder.Output()) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py index 531a18f7a..b428f8302 100644 --- a/mediapipe/tasks/python/test/test_utils.py +++ b/mediapipe/tasks/python/test/test_utils.py @@ -43,3 +43,13 @@ def get_test_data_path(file_or_dirname: str) -> str: if f.endswith(file_or_dirname): return os.path.join(directory, f) raise ValueError("No %s in test directory" % file_or_dirname) + + +def create_calibration_file(file_dir: str, + file_name: str = "score_calibration.txt", + content: str = "1.0,2.0,3.0,4.0") -> str: + """Creates the calibration file.""" + calibration_file = os.path.join(file_dir, file_name) + with open(calibration_file, mode="w") as file: + file.write(content) + return calibration_file diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 8bda87ae2..9f50368b8 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -33,7 +33,21 @@ mediapipe_files(srcs = [ exports_files([ "external_file", + "general_meta.json", "golden_json.json", + "associated_file_meta.json", + "bounding_box_tensor_meta.json", + "classification_tensor_float_meta.json", + "classification_tensor_uint8_meta.json", + "classification_tensor_unsupported_meta.json", + "feature_tensor_meta.json", + "image_tensor_meta.json", + "input_image_tensor_float_meta.json", + "input_image_tensor_uint8_meta.json", + "input_image_tensor_unsupported_meta.json", + "score_calibration.txt", + "score_calibration_file_meta.json", + "score_calibration_tensor_meta.json", ]) filegroup( @@ -51,7 +65,21 @@ filegroup( filegroup( name = "data_files", srcs = [ + "associated_file_meta.json", + "bounding_box_tensor_meta.json", + "classification_tensor_float_meta.json", + "classification_tensor_uint8_meta.json", + "classification_tensor_unsupported_meta.json", "external_file", + "feature_tensor_meta.json", + "general_meta.json", "golden_json.json", + "image_tensor_meta.json", + "input_image_tensor_float_meta.json", + "input_image_tensor_uint8_meta.json", + "input_image_tensor_unsupported_meta.json", + "score_calibration.txt", + "score_calibration_file_meta.json", + "score_calibration_tensor_meta.json", ], ) diff --git a/mediapipe/tasks/testdata/metadata/associated_file_meta.json b/mediapipe/tasks/testdata/metadata/associated_file_meta.json new file mode 100644 index 000000000..2a3d47b96 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/associated_file_meta.json @@ -0,0 +1,10 @@ +{ + "associated_files": [ + { + "name": "label.txt", + "description": "The label file.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/bounding_box_tensor_meta.json b/mediapipe/tasks/testdata/metadata/bounding_box_tensor_meta.json new file mode 100644 index 000000000..55b0624c6 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/bounding_box_tensor_meta.json @@ -0,0 +1,35 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input", + "description": "The input tensor.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "locale": "cn" + } + ] + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/classification_tensor_float_meta.json b/mediapipe/tasks/testdata/metadata/classification_tensor_float_meta.json new file mode 100644 index 000000000..1b146d5ea --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/classification_tensor_float_meta.json @@ -0,0 +1,52 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "probability", + "description": "The classification result tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "cn" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/classification_tensor_uint8_meta.json b/mediapipe/tasks/testdata/metadata/classification_tensor_uint8_meta.json new file mode 100644 index 000000000..f544afdd6 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/classification_tensor_uint8_meta.json @@ -0,0 +1,52 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "probability", + "description": "The classification result tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "cn" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/classification_tensor_unsupported_meta.json b/mediapipe/tasks/testdata/metadata/classification_tensor_unsupported_meta.json new file mode 100644 index 000000000..98cf17884 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/classification_tensor_unsupported_meta.json @@ -0,0 +1,46 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "probability", + "description": "The classification result tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "default_score": 0.2 + } + } + ], + "stats": { + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "cn" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/feature_tensor_meta.json b/mediapipe/tasks/testdata/metadata/feature_tensor_meta.json new file mode 100644 index 000000000..4502d24c2 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/feature_tensor_meta.json @@ -0,0 +1,35 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input", + "description": "The input tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "locale": "cn" + } + ] + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/general_meta.json b/mediapipe/tasks/testdata/metadata/general_meta.json new file mode 100644 index 000000000..70991ee06 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/general_meta.json @@ -0,0 +1,7 @@ +{ + "name": "model", + "description": "A ML model.", + "version": "v1", + "author": "MediaPipe", + "license": "Apache" +} diff --git a/mediapipe/tasks/testdata/metadata/image_tensor_meta.json b/mediapipe/tasks/testdata/metadata/image_tensor_meta.json new file mode 100644 index 000000000..834848cf8 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/image_tensor_meta.json @@ -0,0 +1,35 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input", + "description": "The input tensor.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "locale": "cn" + } + ] + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/input_image_tensor_float_meta.json b/mediapipe/tasks/testdata/metadata/input_image_tensor_float_meta.json new file mode 100644 index 000000000..2f9b288c3 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/input_image_tensor_float_meta.json @@ -0,0 +1,47 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "The input image.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0, + 127.5, + 255.0 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 2.0, + 1.0, + 0.0 + ], + "min": [ + 0.0, + -1.0, + -2.0 + ] + } + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/input_image_tensor_uint8_meta.json b/mediapipe/tasks/testdata/metadata/input_image_tensor_uint8_meta.json new file mode 100644 index 000000000..fc1d84023 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/input_image_tensor_uint8_meta.json @@ -0,0 +1,43 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "The input image.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0, + 127.5, + 255.0 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/input_image_tensor_unsupported_meta.json b/mediapipe/tasks/testdata/metadata/input_image_tensor_unsupported_meta.json new file mode 100644 index 000000000..09a05aabd --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/input_image_tensor_unsupported_meta.json @@ -0,0 +1,37 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "The input image.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0, + 127.5, + 255.0 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + } + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/score_calibration.txt b/mediapipe/tasks/testdata/metadata/score_calibration.txt new file mode 100644 index 000000000..6b3e1dc7f --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/score_calibration.txt @@ -0,0 +1,511 @@ + +0.9876328110694885,0.36622241139411926,0.5352765321731567,0.71484375 +0.9584911465644836,1.0602262020111084,0.2777034342288971,0.019999999552965164 +0.9698624014854431,0.8795201778411865,0.539591908454895,0.00390625 +0.7486230731010437,1.1876736879348755,2.552982807159424,0.019999999552965164 +0.9745277166366577,0.3739396333694458,0.4621727764606476,0.19921875 +0.9683839678764343,0.6996201276779175,0.7690851092338562,0.019999999552965164 +0.6875,0.31044548749923706,1.0056899785995483,0.019999999552965164 +0.9849396347999573,0.8532888889312744,-0.2361421436071396,0.03125 +0.9878578186035156,1.0118975639343262,0.13313621282577515,0.359375 +0.9915205836296082,0.4434199929237366,1.0268371105194092,0.05078125 +0.9370332360267639,0.4586562216281891,-0.08101099729537964,0.019999999552965164 +0.9905818104743958,0.8670706152915955,0.012704282067716122,0.019999999552965164 +0.9080020189285278,0.8507471680641174,0.5081117749214172,0.019999999552965164 +0.985953152179718,0.9933826923370361,-0.8114940524101257,0.109375 +0.9819648861885071,1.12098228931427,-0.6330763697624207,0.01171875 +0.9025918245315552,0.7803755402565002,0.03275677561759949,0.08984375 +0.9863958954811096,0.11243592947721481,0.935604453086853,0.61328125 +0.9905291795730591,0.3710605800151825,0.708966851234436,0.359375 +0.9917052984237671,0.9596433043479919,0.19800108671188354,0.09765625 +0.8762937188148499,0.3449830114841461,0.5352474451065063,0.0078125 +0.9902125000953674,0.8918796181678772,-0.1306992471218109,0.26171875 + +0.9902340173721313,0.9177873134613037,-0.4322589933872223,0.019999999552965164 +0.9707600474357605,0.7028177976608276,0.9813734889030457,0.019999999552965164 +0.9823090434074402,1.0499590635299683,0.12045472860336304,0.0078125 +0.990516185760498,0.9449402093887329,1.3773189783096313,0.019999999552965164 +0.9875434041023254,0.577914297580719,1.282518982887268,0.0390625 +0.9821421504020691,0.0967339277267456,0.8279788494110107,0.47265625 +0.9875047206878662,0.9038218259811401,2.1208062171936035,0.38671875 +0.9857864379882812,0.8627446889877319,0.18189261853694916,0.019999999552965164 +0.9647751450538635,1.0752476453781128,-0.018294010311365128,0.0234375 +0.9830358624458313,0.5638481378555298,0.8346489667892456,0.019999999552965164 +0.9904966354370117,1.0160938501358032,-0.0573287308216095,0.00390625 +0.8458405137062073,0.4868394434452057,0.6617084741592407,0.019999999552965164 +0.9847381711006165,0.5939620137214661,0.008616370148956776,0.00390625 +0.9375938773155212,0.723095178604126,0.6635608077049255,0.019999999552965164 +0.9334303140640259,0.5689108967781067,0.37019580602645874,0.019999999552965164 +0.9716793894767761,1.0037211179733276,0.5898993611335754,0.02734375 +0.9197732210159302,0.46794334053993225,0.7365336418151855,0.640625 +0.9857497811317444,0.7299028635025024,0.9195274114608765,0.0390625 +0.8758038282394409,1.200216293334961,0.02580185979604721,0.019999999552965164 +0.9841026067733765,0.8050475716590881,0.9698556661605835,0.0078125 +0.9908539652824402,0.7911490201950073,0.19351358711719513,0.12109375 +0.9179956316947937,0.023991893976926804,0.35193610191345215,0.04296875 +0.9903728365898132,0.7744967341423035,0.2686336636543274,0.359375 +0.906022846698761,0.5766159892082214,1.0600007772445679,0.04296875 +0.9885554909706116,0.99117511510849,0.5611960291862488,0.4140625 +0.9906331896781921,1.1376535892486572,1.45369291305542,0.019999999552965164 +0.9640991687774658,0.5387894511222839,1.1824018955230713,0.019999999552965164 +0.9932155609130859,0.4347895085811615,1.3938102722167969,0.0078125 +0.9884702563285828,0.885567843914032,0.1556047648191452,0.1484375 +0.9891508221626282,0.04143073782324791,0.6111864447593689,0.0078125 +0.8935436010360718,0.2937895655632019,0.3215920031070709,0.00390625 +0.8327123522758484,0.8381986021995544,-0.026293788105249405,0.019999999552965164 +0.9839455485343933,0.9581400156021118,1.495324969291687,0.640625 +0.9904995560646057,0.9168422818183899,0.33293962478637695,0.015625 +0.9856975674629211,1.0433714389801025,0.5954801440238953,0.019999999552965164 +0.9942344427108765,0.7206616997718811,1.666426181793213,0.9609375 +0.8182767033576965,0.9546273946762085,0.5500107407569885,0.019999999552965164 +0.9631295800209045,0.6277880668640137,0.05952891707420349,0.05859375 +0.9819005727767944,1.0826934576034546,0.7444049715995789,0.30859375 +0.9884315133094788,1.0500890016555786,1.1161768436431885,0.019999999552965164 +0.9175815582275391,0.09232989698648453,1.596696138381958,0.47265625 +0.9868760108947754,0.903079628944397,-0.15774966776371002,0.8515625 +0.9866015911102295,0.7533788084983826,0.7489103078842163,0.03125 +0.8074312806129456,0.8615151643753052,0.40621864795684814,0.00390625 +0.9829285144805908,0.8954831957817078,0.4462486207485199,0.02734375 +0.9681841135025024,0.6257772445678711,0.43809664249420166,0.38671875 +0.9872947931289673,0.9947993159294128,0.9271130561828613,0.26171875 +0.7997345328330994,0.3995186686515808,-0.3755347430706024,0.019999999552965164 +0.9922754168510437,1.1357101202011108,-0.10267537832260132,0.5 +0.9861471652984619,0.8725204467773438,1.1657888889312744,0.019999999552965164 +0.9888646006584167,1.2098380327224731,-0.27832522988319397,0.05078125 +0.5641342997550964,1.0501892566680908,1.9519661664962769,0.019999999552965164 +0.9548168778419495,0.8971696496009827,1.378737449645996,0.00390625 +0.9875019788742065,0.8718118071556091,0.5476236939430237,0.0078125 +0.9725168347358704,0.6989551782608032,-1.3157455921173096,0.61328125 +0.9864014983177185,0.7576251029968262,-0.41650667786598206,0.00390625 +0.960071861743927,0.13068856298923492,0.4819187819957733,0.019999999552965164 +0.9849705100059509,0.7724528908729553,0.3877875804901123,0.03125 +0.9703006744384766,0.8848260641098022,-1.1767181158065796,0.80078125 +0.9837008714675903,0.7015050053596497,0.18209102749824524,0.00390625 +0.9579976797103882,0.053806986659765244,2.7309608459472656,0.4000000059604645 +0.9896979928016663,0.41135814785957336,0.5738034844398499,0.019999999552965164 +0.9853873252868652,0.5438565611839294,0.20562179386615753,0.02734375 +0.9784129858016968,0.6330984830856323,-0.1789831817150116,0.015625 +0.9375,0.855596125125885,-0.1933964192867279,0.019999999552965164 +0.9524176716804504,0.08709807693958282,0.6299692988395691,0.33203125 +0.9808038473129272,1.2909820079803467,0.3397117257118225,0.00390625 +0.8008236885070801,0.7974631786346436,1.0567312240600586,0.019999999552965164 +0.9421642422676086,0.6754576563835144,0.32419073581695557,0.23828125 +0.9072281718254089,1.1716840267181396,-0.10382208228111267,0.00390625 +0.9497162103652954,1.1582106351852417,-0.11845408380031586,0.00390625 +0.9773319959640503,0.5042116641998291,1.2815768718719482,0.23828125 +0.9743752479553223,1.1731196641921997,0.48585158586502075,0.1640625 +0.9601503610610962,1.0114264488220215,-0.9113408327102661,0.38671875 +0.97279292345047,0.32572469115257263,0.548393964767456,0.01171875 +0.9845231175422668,0.9852075576782227,1.0973742008209229,0.69140625 +0.9764596223831177,0.2248251885175705,0.8963963985443115,0.33203125 +0.8746626377105713,0.016590777784585953,1.4492003917694092,0.359375 +0.9726155996322632,0.8712832927703857,-0.6451321840286255,0.52734375 +0.980800211429596,0.8469374775886536,0.0718703418970108,0.04296875 +0.7734344005584717,0.8508065342903137,0.4233662784099579,0.019999999552965164 +0.969182550907135,0.8082079887390137,-0.4314402937889099,0.0234375 +0.9037994742393494,0.1387290209531784,1.8660004138946533,0.5 +0.9869260191917419,0.6927974820137024,0.4927133619785309,0.019999999552965164 +0.8794143795967102,0.8060213327407837,-0.6247795820236206,0.09765625 +0.9895913600921631,0.8851431012153625,0.9641156196594238,0.28515625 +0.9833245873451233,0.9379183053970337,1.5143399238586426,0.0078125 +0.26580730080604553,1.488408088684082,2.5120370388031006,0.019999999552965164 +0.9859549403190613,1.5805137157440186,0.7283271551132202,0.01171875 +0.9376091361045837,0.6854841709136963,0.20175717771053314,0.00390625 +0.965065598487854,0.7363166213035583,-0.3636060357093811,0.1484375 +0.9904685020446777,0.9182849526405334,0.30159056186676025,0.05859375 +0.5014551877975464,0.7409977912902832,0.2045259326696396,0.019999999552965164 +0.9434370398521423,0.3679845631122589,0.6447131633758545,0.38671875 +0.9806621670722961,0.9568924307823181,1.2417932748794556,0.019999999552965164 +0.9825865626335144,1.2273900508880615,-0.0674915760755539,0.0390625 +0.9859767556190491,0.7635276317596436,-0.8502742648124695,0.109375 +0.9701240658760071,0.46266916394233704,0.38697123527526855,0.0703125 +0.9651575088500977,0.5057743191719055,0.6578569412231445,0.0078125 + +0.9685596227645874,0.6961715817451477,0.20829983055591583,0.015625 +0.9772806167602539,0.8312440514564514,-0.09966880083084106,0.019999999552965164 +0.9718109369277954,0.8248763680458069,1.2387524843215942,0.08984375 +0.9890084266662598,2.0058324337005615,1.7648913860321045,0.019999999552965164 +0.9813475608825684,1.02803373336792,1.4689184427261353,0.019999999552965164 +0.9925220608711243,0.8020634055137634,0.7509317994117737,0.015625 +0.9754987955093384,0.5145153999328613,0.4638928472995758,0.00390625 +0.9735408425331116,0.7434492111206055,0.06251777708530426,0.01171875 +0.8753963112831116,1.6830265522003174,4.509310722351074,0.019999999552965164 +0.9385876655578613,0.46194836497306824,0.13496099412441254,0.13671875 +0.9676342010498047,0.5462782979011536,0.9306238889694214,0.1796875 +0.9829097986221313,0.8054409623146057,0.11194216459989548,0.08984375 +0.9503080248832703,0.44028621912002563,0.4689175486564636,0.00390625 +0.9808863997459412,0.8023126721382141,-0.022534284740686417,0.015625 +0.9079821109771729,0.33415740728378296,0.544142484664917,0.019999999552965164 +0.9839802980422974,0.9184480905532837,0.2658761739730835,0.1484375 +0.75,0.8216301798820496,0.3300539255142212,0.019999999552965164 +0.9590148329734802,0.722118616104126,0.255025178194046,0.015625 +0.9616804122924805,0.8398274779319763,0.33006206154823303,0.019999999552965164 +0.7859238386154175,0.5596626400947571,0.5452361702919006,0.019999999552965164 +0.9842674732208252,0.07029404491186142,1.189304232597351,0.30859375 +0.7237641215324402,0.2756437361240387,-0.10612351447343826,0.019999999552965164 +0.9793540239334106,0.5117573738098145,0.8033715486526489,0.01953125 +0.9825188517570496,0.3965616822242737,0.17742407321929932,0.019999999552965164 +0.9859991073608398,1.32109534740448,0.5763598084449768,0.019999999552965164 +0.9551243782043457,0.3639756143093109,0.19449777901172638,0.00390625 +0.9606218338012695,0.8222983479499817,0.43461644649505615,0.00390625 +0.9785885810852051,0.9104304909706116,0.2279568761587143,0.01171875 +0.9705367684364319,0.0769517719745636,0.7330215573310852,0.04296875 +0.9736841320991516,0.9110560417175293,0.10864781588315964,0.05859375 +0.9880238771438599,1.1702078580856323,0.05487633869051933,0.00390625 +0.9913991093635559,0.7445327043533325,1.2198610305786133,0.01171875 +0.8302573561668396,0.33997753262519836,1.0731935501098633,0.019999999552965164 +0.9880614280700684,0.9227356910705566,2.1198885440826416,0.61328125 +0.9173498153686523,0.2221490740776062,0.11565151065587997,0.0078125 +0.962620735168457,1.011454701423645,-1.5519139766693115,0.8203125 +0.9828791618347168,0.7543124556541443,0.29118794202804565,0.00390625 +0.9908701181411743,0.8183356523513794,0.48734790086746216,0.019999999552965164 +0.5002585649490356,0.12179236859083176,0.20199841260910034,0.019999999552965164 +0.9631574153900146,0.41631683707237244,1.1000276803970337,0.44140625 +0.9875426888465881,0.8117235898971558,0.8689690232276917,0.08203125 +0.9410585761070251,0.3703889548778534,0.7951740026473999,0.0078125 +0.9877454042434692,0.2155231237411499,1.635109305381775,0.94921875 +0.9860436320304871,1.0054532289505005,-0.9608616232872009,0.03125 +0.9721421003341675,0.5174740552902222,0.43327680230140686,0.0078125 + +0.9908374547958374,0.8122930526733398,0.21533408761024475,0.0078125 +0.9896888136863708,0.7030488848686218,-0.062063876539468765,0.01953125 +0.9861313700675964,0.49431633949279785,0.981758177280426,0.12109375 +0.9792494177818298,1.0670701265335083,0.7028639316558838,0.019999999552965164 +0.9871346950531006,1.3606067895889282,-3.00394868850708,0.61328125 +0.9583333134651184,0.9180184602737427,-0.05760742351412773,0.019999999552965164 +0.9764145612716675,0.5258041024208069,1.1425464153289795,0.019999999552965164 +0.9076833128929138,1.081973910331726,0.6340405344963074,0.019999999552965164 +0.9895729422569275,0.27958083152770996,1.2441545724868774,0.08203125 +0.916824221611023,0.7878308892250061,-1.3060243129730225,0.359375 +0.9883677363395691,0.6098470687866211,0.7665972709655762,0.52734375 +0.949999988079071,0.818132758140564,1.5476282835006714,0.019999999552965164 + +0.9666821360588074,0.707548201084137,0.7326748967170715,0.00390625 +0.9861665368080139,0.7194502353668213,2.1585183143615723,0.38671875 +0.9811879992485046,0.32190269231796265,0.31508582830429077,0.05078125 +0.9625869989395142,0.11173010617494583,0.9030138850212097,0.019999999552965164 +0.9675677418708801,0.49738144874572754,0.5481624007225037,0.019999999552965164 +0.9764066934585571,1.0306450128555298,0.2257029116153717,0.00390625 +0.9857029318809509,0.8312124013900757,-0.12777498364448547,0.00390625 +0.9781621098518372,0.621485710144043,0.3126043975353241,0.21875 +0.9705549478530884,0.15182119607925415,1.7296228408813477,0.13671875 +0.9801698923110962,0.8953424692153931,0.6697174310684204,0.019999999552965164 +0.9842199087142944,0.7984838485717773,0.7436375617980957,0.0078125 +0.9159231185913086,0.05519663542509079,0.011483916081488132,0.47265625 +0.9742691516876221,0.9268448352813721,1.1530364751815796,0.019999999552965164 +0.9579406380653381,0.7879363894462585,1.1582229137420654,0.00390625 +0.8999202251434326,0.8120636343955994,0.37021151185035706,0.019999999552965164 +0.9870507121086121,1.1666820049285889,1.387096881866455,0.019999999552965164 +0.9769532680511475,0.6519474983215332,0.3170791268348694,0.109375 +0.9546447396278381,0.7559569478034973,0.9533731937408447,0.0078125 +0.9773718118667603,1.3183629512786865,1.0090563297271729,0.019999999552965164 +0.9049819707870483,1.0706751346588135,1.7704588174819946,0.019999999552965164 +0.9003662467002869,0.7251236438751221,-1.4905513525009155,0.4140625 +0.9834321141242981,0.5246152877807617,1.2191725969314575,0.47265625 +0.9748008847236633,0.8448761105537415,-0.01744924671947956,0.00390625 +0.9904628396034241,0.8762193918228149,0.22459718585014343,0.01171875 +0.6833457946777344,0.8996955752372742,1.2423095703125,0.019999999552965164 +0.9909645318984985,0.8978683948516846,0.7022045254707336,0.019999999552965164 +0.9843918681144714,0.12815311551094055,1.5720607042312622,0.78125 +0.9382115602493286,0.4989806115627289,1.1206520795822144,0.03515625 +0.9832627177238464,0.6727185845375061,0.2797912657260895,0.08984375 +0.8830162286758423,1.1294968128204346,1.1474463939666748,0.019999999552965164 +0.9554208517074585,0.9476046562194824,0.8490120768547058,0.019999999552965164 +0.98823082447052,0.7835749983787537,0.5608289837837219,0.03515625 +0.9790570139884949,0.9982950091362,0.3763321042060852,0.00390625 +0.5039305686950684,0.9079190492630005,1.265581488609314,0.019999999552965164 +0.9871423840522766,0.6633929014205933,0.09028752893209457,0.019999999552965164 +0.8614975214004517,0.9595098495483398,-0.5349600315093994,0.00390625 +0.9873358011245728,0.698331892490387,0.7571848630905151,0.1484375 +0.7227392196655273,1.1300171613693237,1.1754553318023682,0.019999999552965164 +0.9814568758010864,0.46864795684814453,0.6286783218383789,0.19921875 +0.9876973032951355,0.29863566160202026,0.7726709842681885,0.61328125 +0.9887779951095581,1.1818888187408447,-1.0321481227874756,0.38671875 +0.9684743285179138,0.7226923108100891,0.0908145159482956,0.0390625 +0.9854185581207275,1.0576037168502808,0.35190048813819885,0.0078125 +0.9463624954223633,0.781932532787323,0.7598024606704712,0.01171875 +0.9837555885314941,0.8735848665237427,0.5948384404182434,0.019999999552965164 +0.9700835347175598,0.45710718631744385,2.141801357269287,0.8359375 +0.9896127581596375,1.018708348274231,0.23626597225666046,0.01953125 +0.7728451490402222,8.084141001063472e-08,0.7415778636932373,0.4000000059604645 +0.9838477969169617,0.8994008302688599,0.15494465827941895,0.00390625 +0.9421281218528748,0.4648025333881378,0.12706322968006134,0.00390625 +0.9843724370002747,1.0055731534957886,-0.911835253238678,0.23828125 +0.958256185054779,1.1208757162094116,-0.31016042828559875,0.0078125 +0.9832971692085266,0.056124646216630936,1.7148709297180176,0.23828125 +0.9804430603981018,0.4016909897327423,0.6085042357444763,0.0703125 +0.9825966358184814,0.9228396415710449,0.912163257598877,0.019999999552965164 +0.9441317915916443,0.048142336308956146,0.6141980290412903,0.109375 +0.9856440424919128,0.8616625666618347,0.28943121433258057,0.015625 +0.9913654923439026,1.0482347011566162,0.6889304518699646,0.015625 +0.97914719581604,0.8870795369148254,-0.700239360332489,0.015625 +0.9836585521697998,0.5450212955474854,0.009687358513474464,0.01953125 +0.990472137928009,0.8221097588539124,2.5926225185394287,0.97265625 +0.6274135708808899,0.6787079572677612,0.12988793849945068,0.015625 +0.982601523399353,0.7495649456977844,1.2217103242874146,0.019999999552965164 +0.9841020703315735,0.9071263670921326,1.3682825565338135,0.09765625 +0.9872562885284424,0.818276584148407,-0.14663955569267273,0.05859375 +0.5041943192481995,0.35444244742393494,0.46112486720085144,0.00390625 +0.7517910599708557,0.91172856092453,1.3611085414886475,0.019999999552965164 +0.9861181378364563,1.0613479614257812,-0.46272075176239014,0.015625 +0.9914185404777527,0.9464229941368103,1.2103853225708008,0.0234375 +0.984909176826477,0.5985794067382812,0.7704220414161682,0.08203125 +0.9575125575065613,0.7695640325546265,0.6132461428642273,0.00390625 +0.9845197200775146,0.7421835064888,1.332088589668274,0.019999999552965164 +0.9470700621604919,0.357934832572937,1.0986406803131104,0.359375 +0.9287161231040955,0.6833012104034424,0.373298704624176,0.00390625 +0.9531774520874023,0.3247152864933014,0.6011538505554199,0.66796875 +0.9779354929924011,0.828241229057312,0.3349589705467224,0.03125 +0.9863978028297424,0.932086169719696,0.04865559563040733,0.02734375 +0.9826814532279968,0.06353739649057388,1.879408359527588,0.61328125 +0.974474310874939,0.8063777685165405,0.8257133364677429,0.019999999552965164 +0.9670184254646301,0.09195757657289505,1.7024414539337158,0.5 +0.9885809421539307,0.7981435656547546,-0.11792337149381638,0.0703125 +0.9829109907150269,0.9578585028648376,-1.9371291399002075,0.13671875 +0.9754639863967896,1.137816071510315,0.5887423157691956,0.00390625 +0.9755549430847168,0.677255392074585,0.20494212210178375,0.00390625 +0.9903355836868286,1.0475162267684937,2.1768462657928467,0.52734375 +0.9855127930641174,0.9580414891242981,0.35021960735321045,0.76171875 +0.9450457692146301,0.4737727642059326,-0.3041325807571411,0.01171875 +0.9360163807868958,0.9219141006469727,1.2481396198272705,0.019999999552965164 +0.9696909189224243,0.06589268147945404,1.456658124923706,0.30000001192092896 +0.6495901942253113,0.8538134098052979,0.3043774366378784,0.019999999552965164 +0.9901140928268433,0.8112474679946899,0.7102972269058228,0.019999999552965164 +0.9925929307937622,0.49307680130004883,0.6297348737716675,0.019999999552965164 +0.9840761423110962,0.5691578388214111,0.9437046647071838,0.00390625 +0.9625457525253296,0.9322702288627625,1.3358750343322754,0.0234375 +0.9820173978805542,0.6805416345596313,1.0065922737121582,0.05859375 +0.9883391261100769,0.742003321647644,0.6168643236160278,0.0078125 +0.9119130969047546,0.8404607176780701,0.8882355690002441,0.01171875 +0.9854885935783386,1.295777440071106,0.5272557735443115,0.00390625 +0.9911734461784363,1.152715802192688,-0.05230601131916046,0.019999999552965164 +0.8071879744529724,0.4576769471168518,1.391660451889038,0.00390625 +0.9919166564941406,1.1775370836257935,0.5039792060852051,0.019999999552965164 +0.9831258654594421,0.9164834022521973,0.3790256977081299,0.01171875 +0.990642249584198,0.9242916107177734,1.477474570274353,0.38671875 +0.7415178418159485,0.2909083068370819,0.19971248507499695,0.019999999552965164 +0.9146556854248047,0.06850286573171616,1.3211928606033325,0.61328125 +0.976986825466156,0.6469135284423828,-0.7279839515686035,0.02734375 +0.968462347984314,0.4640704393386841,1.4650955200195312,0.1484375 +0.937825083732605,0.9767780303955078,-0.7378027439117432,0.0390625 +0.9878604412078857,1.1423084735870361,1.7311146259307861,0.1484375 +0.9904257655143738,0.9551829099655151,1.564165472984314,0.00390625 +0.9830996990203857,0.92529296875,-0.1086890697479248,0.02734375 + +0.9820512533187866,0.7556048631668091,0.6512532830238342,0.109375 +0.9740781188011169,0.8380919098854065,0.19731587171554565,0.019999999552965164 +0.9830799698829651,1.183397650718689,-0.801214873790741,0.019999999552965164 +0.9898439049720764,1.168870210647583,1.2985308170318604,0.00390625 +0.97286057472229,0.8012385964393616,-1.657444953918457,0.09765625 +0.9182834625244141,0.5254654884338379,-0.027080848813056946,0.04296875 +0.9729798436164856,0.4111078381538391,1.077646255493164,0.019999999552965164 +0.6875,1.756393551826477,0.34522199630737305,0.019999999552965164 +0.9920725226402283,1.0676580667495728,1.1592471599578857,0.019999999552965164 +0.37564563751220703,0.07466565072536469,0.3562135696411133,0.019999999552965164 +0.9894161224365234,0.8109862804412842,1.3056280612945557,0.0390625 +0.9386259317398071,0.5322021842002869,-0.03461914509534836,0.08984375 +0.9866133332252502,0.8940346240997314,1.0361984968185425,0.00390625 +0.9822850823402405,0.6215930581092834,-0.6859042048454285,0.00390625 +0.9752063155174255,1.0129338502883911,0.3866007626056671,0.019999999552965164 +0.9825329184532166,0.567034125328064,0.5370683670043945,0.5 +0.9422088861465454,0.9411858320236206,0.5332568883895874,0.38671875 +0.9506444931030273,0.7494101524353027,0.9869776368141174,0.00390625 +0.9923189282417297,1.1255286931991577,0.8734608292579651,0.019999999552965164 +0.9807777404785156,0.9558923244476318,1.5415621995925903,0.09765625 +0.961335301399231,0.7840818762779236,0.06915930658578873,0.00390625 +0.9867202639579773,1.0596263408660889,0.21268242597579956,0.0078125 +0.9926426410675049,0.8886650204658508,0.6200761198997498,0.019999999552965164 +0.9791930913925171,0.4474319517612457,0.5827012062072754,0.019999999552965164 +0.986801028251648,1.1846712827682495,1.4253416061401367,0.00390625 +0.9549052119255066,0.6142332553863525,0.4867286682128906,0.00390625 +0.983259916305542,0.42561075091362,0.9666317105293274,0.08203125 +0.98175048828125,0.7744573354721069,0.4953071177005768,0.019999999552965164 +0.987273097038269,0.8209654092788696,0.5267868041992188,0.019999999552965164 +0.9916341304779053,0.6881924271583557,0.9522916078567505,0.019999999552965164 +0.9819192886352539,0.8128346800804138,0.6556753516197205,0.05859375 +0.9854727387428284,0.6597779393196106,0.9645410180091858,0.8359375 +0.9891805648803711,0.7752296924591064,1.34084153175354,0.52734375 +0.9489904046058655,0.6988677978515625,0.5052891969680786,0.019999999552965164 +0.9741962552070618,0.43797168135643005,0.7825477123260498,0.01171875 +0.9907783269882202,0.8732656240463257,1.1458243131637573,0.19921875 +0.9760454297065735,0.7810378670692444,-0.29553040862083435,0.015625 +0.9885720014572144,0.8427382707595825,0.2628841996192932,0.019999999552965164 +0.8171960115432739,0.3271152079105377,1.30915105342865,0.26171875 +0.9881270527839661,0.13021250069141388,1.6307408809661865,0.55859375 +0.9751906991004944,0.8255484104156494,0.21788427233695984,0.019999999552965164 +0.9630831480026245,2.1396600701476974e-15,2.883542776107788,0.5 +0.8849332332611084,0.888649582862854,1.0651483535766602,0.01171875 +0.9897550344467163,0.08640030771493912,2.661073923110962,0.69140625 +0.9030827879905701,0.7017505168914795,0.07822071760892868,0.00390625 +0.9650112986564636,0.36098214983940125,0.7112777829170227,0.0078125 +0.9872719049453735,0.7115703821182251,0.6924230456352234,0.019999999552965164 +0.5884749889373779,0.0942283645272255,0.24825790524482727,0.019999999552965164 +0.9642857313156128,0.5304845571517944,0.6281308531761169,0.019999999552965164 +0.9651434421539307,0.07168509066104889,1.4704163074493408,0.61328125 +0.9779187440872192,1.0171563625335693,-2.8089962005615234,0.1484375 +0.9375227689743042,0.9291267991065979,0.6853470802307129,0.019999999552965164 +0.9820515513420105,0.7226945757865906,-0.19336646795272827,0.61328125 +0.984882652759552,0.8176864385604858,1.161419153213501,0.0078125 +0.9573767185211182,0.9027169346809387,0.15423306822776794,0.26171875 +0.9059234261512756,0.872424840927124,0.7419941425323486,0.019999999552965164 +0.9914654493331909,1.0662620067596436,2.7141172885894775,0.55859375 +0.9839044809341431,0.9037585854530334,0.7042809724807739,0.01953125 +0.986689567565918,0.6848335266113281,0.9014078974723816,0.00390625 +0.9837497472763062,0.7507086396217346,0.7179840207099915,0.0078125 +0.9895229339599609,1.1564929485321045,0.5822750926017761,0.019999999552965164 +0.9845471978187561,0.8716567158699036,0.19987598061561584,0.01953125 +0.971385657787323,0.49073365330696106,1.2333439588546753,0.73828125 +0.9841684699058533,0.6468350887298584,1.0000839233398438,0.0703125 +0.9882851839065552,0.26080548763275146,0.8985073566436768,0.01171875 +0.9851044416427612,0.8687262535095215,0.07842865586280823,0.1796875 +0.9799972772598267,0.25032666325569153,1.2494641542434692,0.10000000149011612 +0.9896620512008667,0.7762697339057922,0.20227234065532684,0.019999999552965164 +0.990495502948761,0.15801414847373962,1.006077766418457,0.01171875 +0.9806667566299438,0.7082678079605103,0.35462483763694763,0.02734375 +0.9715457558631897,0.0615643672645092,0.9478678703308105,0.4000000059604645 +0.9168440103530884,0.5679594874382019,-0.6143214106559753,0.1484375 +0.9824567437171936,0.45072048902511597,1.0683321952819824,0.1484375 +0.9840478301048279,0.08733312040567398,1.3535010814666748,0.47265625 +0.9896746873855591,1.1761761903762817,0.7102295756340027,0.94140625 +0.9827673435211182,0.8215981125831604,0.6729252338409424,0.019999999552965164 +0.9906817674636841,0.16318124532699585,1.133107304573059,0.30000001192092896 +0.9701097011566162,1.0519390106201172,-0.16105352342128754,0.00390625 +0.9417809844017029,0.7868722081184387,1.1539735794067383,0.019999999552965164 +0.9615354537963867,0.8469739556312561,0.6801642179489136,0.0390625 +0.988472580909729,0.81600022315979,0.6296193599700928,0.019999999552965164 +0.9841001629829407,0.8400164246559143,-0.06806250661611557,0.00390625 +0.9276565313339233,0.32582467794418335,-0.14148345589637756,0.019999999552965164 +0.7008209228515625,0.545078694820404,1.1250351667404175,0.019999999552965164 +0.9907881021499634,0.9919379353523254,-0.12143492698669434,0.019999999552965164 +0.9702130556106567,0.7762024402618408,0.24524429440498352,0.0078125 +0.9876235723495483,0.7181832790374756,0.41931474208831787,0.019999999552965164 +0.9841905236244202,0.8836563229560852,0.28947240114212036,0.00390625 +0.990247905254364,0.9825950860977173,0.6003378033638,0.00390625 +0.9635987281799316,0.3707619905471802,-0.03457726538181305,0.0390625 +0.9924789071083069,1.485293984413147,0.5796234607696533,0.00390625 +0.9839015603065491,0.06343062222003937,1.9442640542984009,0.5 +0.9927193522453308,0.7006005048751831,0.3714500069618225,0.019999999552965164 +0.9870567321777344,0.869498610496521,1.5008329153060913,0.00390625 +0.9002388119697571,0.4945279657840729,-0.27996397018432617,0.0078125 +0.98891282081604,0.8541091680526733,0.5112633109092712,0.66796875 +0.9001862406730652,0.43330734968185425,0.3592444360256195,0.00390625 +0.958705723285675,0.7425220012664795,0.15833647549152374,0.00390625 +0.9910086989402771,0.9245886206626892,0.8454338908195496,0.01953125 +0.9912900328636169,1.3806378841400146,1.0953043699264526,0.99609375 +0.9887956976890564,1.0331758260726929,0.6490115523338318,0.640625 +0.8638584017753601,0.902369499206543,-0.2767508327960968,0.0078125 +0.7059138417243958,1.0,1.032223091723683e-11,0.019999999552965164 +0.9889519810676575,0.8361310362815857,0.811896800994873,0.03515625 +0.970467209815979,0.07315781712532043,0.20799599587917328,0.00390625 +0.9828550219535828,0.8393198251724243,0.6089786291122437,0.28515625 +0.9553551077842712,0.7775288820266724,-0.4464336037635803,0.046875 +0.9782186150550842,0.4313304126262665,0.4458310604095459,0.019999999552965164 +0.9371097087860107,0.9338632225990295,1.3358187675476074,0.019999999552965164 +0.9861361384391785,0.24091234803199768,1.4301774501800537,0.80078125 +0.9890525341033936,1.1365840435028076,0.3055979013442993,0.00390625 +0.957517683506012,0.058012738823890686,0.15909947454929352,0.046875 +0.9762251377105713,0.72292160987854,0.49151331186294556,0.019999999552965164 +0.9875496625900269,0.9114606976509094,-0.5052767992019653,0.05859375 +0.9715835452079773,0.8113637566566467,-2.0302956104278564,0.019999999552965164 +0.9846333265304565,0.49688151478767395,0.7285738587379456,0.019999999552965164 +0.98553466796875,0.1484774351119995,1.3616747856140137,0.5859375 +0.9866309762001038,1.0217945575714111,-0.8717418313026428,0.02734375 +0.9891880750656128,0.42588523030281067,0.7833192944526672,0.109375 +0.9870361685752869,0.8525673151016235,1.2773776054382324,0.019999999552965164 +0.9897037744522095,0.8012522459030151,0.3973642885684967,0.109375 +0.9828903079032898,1.1558295488357544,-0.6781614422798157,0.5859375 +0.9924454689025879,1.1040401458740234,1.3243318796157837,0.019999999552965164 +0.9826735258102417,1.0064337253570557,-0.5324167013168335,0.38671875 +0.949999988079071,0.8152432441711426,0.6293236613273621,0.00390625 +0.9905489087104797,0.9191447496414185,0.5621309876441956,0.019999999552965164 +0.9664857387542725,0.5995981693267822,-0.7409313321113586,0.01171875 +0.9847198724746704,0.8284208178520203,0.2851041555404663,0.9296875 +0.9342833757400513,0.5566492676734924,0.6875373721122742,0.019999999552965164 +0.8894915580749512,0.4102778434753418,0.37977635860443115,0.01953125 +0.9870865941047668,0.44245558977127075,0.16041725873947144,0.10000000149011612 +0.9890456795692444,1.1491310596466064,1.0844204425811768,0.01953125 +0.7304704785346985,0.12790271639823914,-0.1085965558886528,0.019999999552965164 +0.9830618500709534,0.8738722205162048,-0.11583804339170456,0.0234375 +0.9885876178741455,0.744857668876648,0.11028216779232025,0.01953125 +0.9575535655021667,0.3011772632598877,0.5136104226112366,0.00390625 +0.9298899173736572,1.1736249923706055,4.0247297286987305,0.09765625 +0.9907795190811157,1.0897759199142456,0.6261603236198425,0.019999999552965164 +0.9855174422264099,0.6543705463409424,0.08955699950456619,0.08984375 +0.976660430431366,0.5610390901565552,0.6389923095703125,0.0390625 +0.9870068430900574,0.80875563621521,-0.6651867032051086,0.08984375 +0.9652793407440186,0.5887689590454102,0.5353426933288574,0.0703125 +0.9875175952911377,0.7699108123779297,0.876632034778595,0.019999999552965164 +0.9016479849815369,0.9994669556617737,0.30356451869010925,0.015625 +0.989987850189209,0.7350922226905823,0.8748764991760254,0.0078125 +0.983323335647583,0.8931586146354675,1.0226351022720337,0.01171875 +0.9914804100990295,0.9369975328445435,0.8283791542053223,0.019999999552965164 +0.9704275727272034,1.124052882194519,0.9457330107688904,0.019999999552965164 +0.9867291450500488,0.9667392373085022,-0.6122757196426392,0.44140625 +0.9887421131134033,0.7823470234870911,0.343982458114624,0.00390625 +0.9861542582511902,0.9171664118766785,0.35665032267570496,0.019999999552965164 +0.9772396683692932,0.08705096691846848,1.7621256113052368,0.66796875 +0.9819098114967346,0.8605496883392334,0.5151250958442688,0.01171875 +0.982971727848053,0.5631197690963745,1.608361005783081,0.019999999552965164 + +0.9914254546165466,0.3850722908973694,1.4068152904510498,0.98828125 +0.9880355596542358,1.1387118101119995,1.4653834104537964,0.05859375 +0.9586950540542603,1.7633997201919556,1.0344760417938232,0.019999999552965164 +0.9828103184700012,0.8817474842071533,0.7680216431617737,0.890625 +0.9880233407020569,0.899823784828186,0.44692227244377136,0.19921875 +0.9862816333770752,0.8610615134239197,0.4195229709148407,0.03125 +0.9813369512557983,0.8014124631881714,1.1136316061019897,0.0078125 +0.9148907661437988,0.5909111499786377,1.2860896587371826,0.015625 +0.9865161776542664,0.8720636963844299,0.6233670115470886,0.015625 +0.9786784648895264,0.48225611448287964,-0.005022380966693163,0.12109375 +0.9843324422836304,1.0519789457321167,-2.2056643962860107,0.03125 +0.9688847064971924,0.8007095456123352,0.14495795965194702,0.1640625 +0.9724696278572083,0.9987169504165649,0.32869264483451843,0.019999999552965164 +0.9875112175941467,1.0948023796081543,2.15657114982605,0.03125 +0.9923174381256104,0.10759950429201126,0.6762840747833252,0.019999999552965164 +0.9666666388511658,0.6234443783760071,1.4971232414245605,0.0390625 +0.989655613899231,0.8248854279518127,0.4701078534126282,0.019999999552965164 +0.9753870368003845,0.6746605634689331,-0.23550045490264893,0.1640625 +0.9170913100242615,1.0504746437072754,2.7344093322753906,0.019999999552965164 +0.9821392297744751,1.4154850244522095,1.2012253999710083,0.019999999552965164 +0.9886221885681152,1.22860586643219,1.160277009010315,0.890625 +0.9877735376358032,0.6805673837661743,1.5975077152252197,0.359375 +0.9831939339637756,0.6648986339569092,1.1059051752090454,0.28515625 +0.950076162815094,0.724887490272522,0.316800057888031,0.019999999552965164 +0.9817547798156738,0.8619367480278015,-0.24251239001750946,0.109375 +0.9849069714546204,0.8399055004119873,1.7567216157913208,0.4000000059604645 +0.9821556806564331,0.8135135769844055,0.33616918325424194,0.0078125 +0.8329862356185913,0.7938078045845032,1.0597797632217407,0.019999999552965164 +0.9856904149055481,0.05120579153299332,0.8267747759819031,0.5 +0.9766159057617188,0.7623113989830017,0.7656452059745789,0.09765625 +0.9885436296463013,0.9814053177833557,0.05546858534216881,0.00390625 +0.9900276064872742,0.9320858716964722,-0.36458709836006165,0.03125 +0.9058290123939514,0.7260504364967346,1.1726433038711548,0.019999999552965164 +0.9503811597824097,0.6632846593856812,0.7332696914672852,0.019999999552965164 +0.9846004247665405,0.6996731758117676,-0.8613988757133484,0.019999999552965164 +0.9897956252098083,0.8407823443412781,1.2952353954315186,0.76171875 +0.9898385405540466,0.7309674024581909,0.7317643761634827,0.019999999552965164 +0.9850022196769714,0.7537633180618286,0.3925366699695587,0.03125 +0.9858620762825012,0.9250133633613586,2.0220303535461426,0.9296875 +0.8120821714401245,0.3994182348251343,-0.4576922655105591,0.019999999552965164 +0.9496838450431824,0.8251343965530396,0.15125347673892975,0.019999999552965164 +0.9420520067214966,0.6087028384208679,1.0767998695373535,0.019999999552965164 +0.9899152517318726,0.8887513279914856,0.9602599143981934,0.019999999552965164 +0.9461711049079895,1.1373282670974731,0.6371906995773315,0.00390625 +0.9834751486778259,0.7226889729499817,0.8995278477668762,0.109375 +0.9850850105285645,1.2857465744018555,-2.2220215797424316,0.38671875 +0.9789451956748962,0.9153420925140381,0.12551555037498474,0.01171875 +0.8774109482765198,0.9271970987319946,0.5529487729072571,0.019999999552965164 +0.9074040651321411,0.920030951499939,0.40618932247161865,0.00390625 +0.9878932237625122,0.5347745418548584,0.8865230679512024,0.046875 +0.937852144241333,1.1346293687820435,-0.3324768841266632,0.019999999552965164 +0.7542195916175842,0.44728168845176697,0.45312440395355225,0.019999999552965164 +0.9915731549263,1.3838905096054077,-0.043990228325128555,0.01171875 +0.9284758567810059,0.4973248541355133,0.9887621998786926,0.019999999552965164 +0.9700435400009155,0.8664135336875916,1.0059133768081665,0.046875 +0.9667003750801086,0.7796391844749451,-0.10554620623588562,0.00390625 +0.9698932766914368,0.7340040802955627,0.4837290942668915,0.00390625 +0.973517894744873,0.9678344130516052,0.36683231592178345,0.00390625 +0.9770389795303345,0.8958415389060974,1.2423408031463623,0.015625 +0.9902989864349365,0.7568255066871643,0.9843511581420898,0.019999999552965164 +0.9908176064491272,0.8731094002723694,0.6906698346138,0.00390625 +0.9901729226112366,0.8561913371086121,0.8783953189849854,0.5859375 diff --git a/mediapipe/tasks/testdata/metadata/score_calibration_file_meta.json b/mediapipe/tasks/testdata/metadata/score_calibration_file_meta.json new file mode 100644 index 000000000..c47d84604 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/score_calibration_file_meta.json @@ -0,0 +1,9 @@ +{ + "associated_files": [ + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/score_calibration_tensor_meta.json b/mediapipe/tasks/testdata/metadata/score_calibration_tensor_meta.json new file mode 100644 index 000000000..fd6c0e4e7 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/score_calibration_tensor_meta.json @@ -0,0 +1,15 @@ +{ + "subgraph_metadata": [ + { + "input_process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ] + } + ] +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index acf05c2fc..254692856 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -22,12 +22,24 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/albert_with_metadata.tflite?generation=1661875651648830"], ) + http_file( + name = "com_google_mediapipe_associated_file_meta_json", + sha256 = "5b2cba11ae893e1226af6570813955889e9f171d6d2c67b3e96ecb6b96d8c681", + urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"], + ) + http_file( name = "com_google_mediapipe_bert_text_classifier_tflite", sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600", urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1663009542017720"], ) + http_file( + name = "com_google_mediapipe_bounding_box_tensor_meta_json", + sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a", + urls = ["https://storage.googleapis.com/mediapipe-assets/bounding_box_tensor_meta.json?generation=1665422797529909"], + ) + http_file( name = "com_google_mediapipe_BUILD", sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3", @@ -70,6 +82,24 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_no_resizing.jpg?generation=1661875687251296"], ) + http_file( + name = "com_google_mediapipe_classification_tensor_float_meta_json", + sha256 = "1d10b1c9c87eabac330651136804074ddc134779e94a73cf783207c3aa2a5619", + urls = ["https://storage.googleapis.com/mediapipe-assets/classification_tensor_float_meta.json?generation=1665422803073223"], + ) + + http_file( + name = "com_google_mediapipe_classification_tensor_uint8_meta_json", + sha256 = "74f4d64ee0017d11e0fdc975a88d974d73b72b889fd4d67992356052edde0f1e", + urls = ["https://storage.googleapis.com/mediapipe-assets/classification_tensor_uint8_meta.json?generation=1665422808178685"], + ) + + http_file( + name = "com_google_mediapipe_classification_tensor_unsupported_meta_json", + sha256 = "4810ad8a00f0078c6a693114d00f692aa70ff2d61030a6e516db1e654707e208", + urls = ["https://storage.googleapis.com/mediapipe-assets/classification_tensor_unsupported_meta.json?generation=1665422813312699"], + ) + http_file( name = "com_google_mediapipe_coco_efficientdet_lite0_v1_1_0_quant_2021_09_06_tflite", sha256 = "dee1b4af055a644804d5594442300ecc9e4f7080c25b7c044c98f527eeabb6cf", @@ -166,6 +196,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"], ) + http_file( + name = "com_google_mediapipe_feature_tensor_meta_json", + sha256 = "b2c30ddfd495956ce81085f8a143422f4310b002cfbf1c594ff2ee0576e29d6f", + urls = ["https://storage.googleapis.com/mediapipe-assets/feature_tensor_meta.json?generation=1665422818797346"], + ) + + http_file( + name = "com_google_mediapipe_general_meta_json", + sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f", + urls = ["https://storage.googleapis.com/mediapipe-assets/general_meta.json?generation=1665422822603848"], + ) + http_file( name = "com_google_mediapipe_golden_json_json", sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c", @@ -208,6 +250,30 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/hand_recrop.tflite?generation=1661875770633070"], ) + http_file( + name = "com_google_mediapipe_image_tensor_meta_json", + sha256 = "aad86fde3defb379c82ff7ee48e50493a58529cdc0623cf0d7bf135c3577060e", + urls = ["https://storage.googleapis.com/mediapipe-assets/image_tensor_meta.json?generation=1665422826106636"], + ) + + http_file( + name = "com_google_mediapipe_input_image_tensor_float_meta_json", + sha256 = "426ecf5c3ace61db3936b950c3709daece15827ea21905ddbcdc81b1c6e70232", + urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_float_meta.json?generation=1665422829230563"], + ) + + http_file( + name = "com_google_mediapipe_input_image_tensor_uint8_meta_json", + sha256 = "dc7ff86b606641e480c7d154b5f467e1f8c895f85733c73ba47a259a66ed187b", + urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_uint8_meta.json?generation=1665422832572887"], + ) + + http_file( + name = "com_google_mediapipe_input_image_tensor_unsupported_meta_json", + sha256 = "443d436c2068df8201b9822c35e724acfd8004a788d388e7d74c38a2425c55df", + urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_unsupported_meta.json?generation=1665422835757143"], + ) + http_file( name = "com_google_mediapipe_iris_and_gaze_tflite", sha256 = "b6dcb860a92a3c7264a8e50786f46cecb529672cdafc17d39c78931257da661d", @@ -472,6 +538,24 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands.jpg?generation=1661875908672404"], ) + http_file( + name = "com_google_mediapipe_score_calibration_file_meta_json", + sha256 = "6a3c305620371f662419a496f75be5a10caebca7803b1e99d8d5d22ba51cda94", + urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration_file_meta.json?generation=1665422841236117"], + ) + + http_file( + name = "com_google_mediapipe_score_calibration_tensor_meta_json", + sha256 = "24cbde7f76dd6a09a55d07f30493c2f254d61154eb2e8d18ed947ff56781186d", + urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration_tensor_meta.json?generation=1665422844327992"], + ) + + http_file( + name = "com_google_mediapipe_score_calibration_txt", + sha256 = "34b0c51a8c79b4515bdd24e440c4b76a9f0fd01ef6385b36af983036e7be6271", + urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"], + ) + http_file( name = "com_google_mediapipe_segmentation_golden_rotation0_png", sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069", From 6fa455f40ea74264fa4182657755cb1e767dae79 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 10 Oct 2022 14:31:46 -0700 Subject: [PATCH 02/18] Dataclasses for text classifier PiperOrigin-RevId: 480178697 --- mediapipe/model_maker/python/core/BUILD | 7 ++ .../python/core/hyperparameters.py | 68 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 mediapipe/model_maker/python/core/hyperparameters.py diff --git a/mediapipe/model_maker/python/core/BUILD b/mediapipe/model_maker/python/core/BUILD index 10aef8c33..9f205bb11 100644 --- a/mediapipe/model_maker/python/core/BUILD +++ b/mediapipe/model_maker/python/core/BUILD @@ -12,8 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Placeholder for internal Python strict library compatibility macro. + package( default_visibility = ["//mediapipe:__subpackages__"], ) licenses(["notice"]) + +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], +) diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py new file mode 100644 index 000000000..2a7a8678c --- /dev/null +++ b/mediapipe/model_maker/python/core/hyperparameters.py @@ -0,0 +1,68 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training models. Shared across tasks.""" + +import dataclasses +import tempfile + +from typing import Optional + + +# TODO: Integrate this class into ImageClassifier and other tasks. +@dataclasses.dataclass +class BaseHParams: + """Hyperparameters used for training models. + + A common set of hyperparameters shared by the training jobs of all model + maker tasks. + + Attributes: + learning_rate: The learning rate to use for gradient descent training. + batch_size: Batch size for training. + epochs: Number of training iterations over the dataset. + steps_per_epoch: An optional integer indicate the number of training steps + per epoch. If not set, the training pipeline calculates the default steps + per epoch as the training dataset size devided by batch size. + shuffle: True if the dataset is shuffled before training. + export_dir: The location of the model checkpoint files. + distribution_strategy: A string specifying which Distribution Strategy to + use. Accepted values are 'off', 'one_device', 'mirrored', + 'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case + insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to + use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy + documentation for more details: + https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy. + num_gpus: How many GPUs to use at each worker with the + DistributionStrategies API. The default is -1, which means utilize all + available GPUs. + tpu: The Cloud TPU to use for training. This should be either the name used + when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url. + """ + + # Parameters for train configuration + learning_rate: float + batch_size: int + epochs: int + steps_per_epoch: Optional[int] = None + + # Dataset-related parameters + shuffle: bool = False + + # Parameters for model / checkpoint files + export_dir: str = tempfile.mkdtemp() + + # Parameters for hardware acceleration + distribution_strategy: str = 'off' + num_gpus: int = -1 # default value of -1 means use all available GPUs + tpu: str = '' From 12f72f067dfe9cf3e2803cbe0b2d60ae2b14aada Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 10 Oct 2022 15:36:03 -0700 Subject: [PATCH 03/18] updated documentation PiperOrigin-RevId: 480193380 --- docs/getting_started/cpp.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/getting_started/cpp.md b/docs/getting_started/cpp.md index 8fc091fea..47fb697b0 100644 --- a/docs/getting_started/cpp.md +++ b/docs/getting_started/cpp.md @@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow ```bash GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \ - --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt ``` This will open up your webcam as long as it is connected and on. Any errors From f4fd1063a71d408b10b79c97793753f3b659cd3e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 11 Oct 2022 13:52:43 -0700 Subject: [PATCH 04/18] Add helper methods to load saved model from external files in model maker. PiperOrigin-RevId: 480444918 --- .../python/core/utils/model_util.py | 26 +++++++++++++++++++ .../python/core/utils/model_util_test.py | 19 +++++++++++--- .../python/core/utils/test_util.py | 18 +++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 4914fea57..8962f2868 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -26,6 +26,7 @@ from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union import numpy as np import tensorflow as tf +# resources dependency from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.utils import quantization @@ -33,6 +34,31 @@ DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 ESTIMITED_STEPS_PER_EPOCH = 1000 +def load_keras_model(model_path: str, + compile_on_load: bool = False) -> tf.keras.Model: + """Loads a tensorflow Keras model from file and returns the Keras model. + + Args: + model_path: Relative path to a directory containing model data, such as + /saved_model/. + compile_on_load: Whether the model should be compiled while loading. If + False, the model returned has to be compiled with the appropriate loss + function and custom metrics before running for inference on a test + dataset. + + Returns: + A tensorflow Keras model. + """ + # Extract the file path before mediapipe/ as the `base_dir`. By joining it + # with the `model_path` which defines the relative path under mediapipe/, it + # yields to the aboslution path of the model files directory. + cwd = os.path.dirname(__file__) + base_dir = cwd[:cwd.rfind('mediapipe')] + absolute_path = os.path.join(base_dir, model_path) + return tf.keras.models.load_model( + absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) + + def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, batch_size: Optional[int] = None, train_data: Optional[dataset.Dataset] = None) -> int: diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 9c3908841..ce31c1877 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -15,7 +15,6 @@ import os from absl.testing import parameterized -import numpy as np import tensorflow as tf from mediapipe.model_maker.python.core.utils import model_util @@ -25,6 +24,18 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): + def test_load_model(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') + model.save(saved_model_path) + loaded_model = model_util.load_keras_model(saved_model_path) + + input_tensors = test_util.create_random_sample(size=[1, input_dim]) + model_output = model.predict_on_batch(input_tensors) + loaded_model_output = loaded_model.predict_on_batch(input_tensors) + self.assertTrue((model_output == loaded_model_output).all()) + @parameterized.named_parameters( dict( testcase_name='input_only_steps_per_epoch', @@ -124,9 +135,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): input_dim: int, max_input_value: int = 1000, atol: float = 1e-04): - np.random.seed(0) - random_input = np.random.uniform( - low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32) + random_input = test_util.create_random_sample( + size=[1, input_dim], high=max_input_value) + random_input = tf.convert_to_tensor(random_input) self.assertTrue( test_util.is_same_output( diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py index eb2952dd3..cac2a0e1f 100644 --- a/mediapipe/model_maker/python/core/utils/test_util.py +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -46,6 +46,24 @@ def create_dataset(data_size: int, return dataset +def create_random_sample(size: Union[int, List[int]], + low: float = 0, + high: float = 1) -> np.ndarray: + """Creates and returns a random sample with floating point values. + + Args: + size: Size of the output multi-dimensional array. + low: Lower boundary of the output values. + high: Higher boundary of the output values. + + Returns: + 1D array if the size is scalar. Otherwise, N-D array whose dimension equals + input size. + """ + np.random.seed(0) + return np.random.uniform(low=low, high=high, size=size).astype(np.float32) + + def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model: """Builds a simple Keras model for test.""" inputs = tf.keras.layers.Input(shape=input_shape) From 1b611c66bb1577d4ecbbf4b4e97347261cbf9b78 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 11 Oct 2022 14:38:14 -0700 Subject: [PATCH 05/18] Improve quantization support in model_maker/image_classifier PiperOrigin-RevId: 480455944 --- mediapipe/model_maker/python/core/utils/model_util.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 8962f2868..0899a9b1a 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -94,7 +94,8 @@ def export_tflite( tflite_filepath: str, quantization_config: Optional[quantization.QuantizationConfig] = None, supported_ops: Tuple[tf.lite.OpsSet, - ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)): + ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), + preprocess: Optional[Callable[..., bool]] = None): """Converts the model to tflite format and saves it. Args: @@ -102,6 +103,9 @@ def export_tflite( tflite_filepath: File path to save tflite model. quantization_config: Configuration for post-training quantization. supported_ops: A list of supported ops in the converted TFLite file. + preprocess: A callable to preprocess the representative dataset for + quantization. The callable takes three arguments in order: feature, + label, and is_training. """ if tflite_filepath is None: raise ValueError( @@ -113,7 +117,8 @@ def export_tflite( converter = tf.lite.TFLiteConverter.from_saved_model(save_path) if quantization_config: - converter = quantization_config.set_converter_with_quantization(converter) + converter = quantization_config.set_converter_with_quantization( + converter, preprocess=preprocess) converter.target_spec.supported_ops = supported_ops tflite_model = converter.convert() From 64deb791dc6814858e6f19133ee691d72b80993e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 11 Oct 2022 15:35:37 -0700 Subject: [PATCH 06/18] Fix empty packet bug with no hands detected. PiperOrigin-RevId: 480469392 --- .../gesture_recognizer/gesture_recognizer.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index ca5deee7f..e0d1473c2 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -180,6 +180,15 @@ absl::StatusOr> GestureRecognizer::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + if (status_or_packets.value()[kHandGesturesStreamName].IsEmpty()) { + Packet empty_packet = + status_or_packets.value()[kHandGesturesStreamName]; + result_callback( + {{{}, {}, {}, {}}}, image_packet.Get(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } Packet gesture_packet = status_or_packets.value()[kHandGesturesStreamName]; Packet handedness_packet = @@ -188,7 +197,6 @@ absl::StatusOr> GestureRecognizer::Create( status_or_packets.value()[kHandLandmarksStreamName]; Packet hand_world_landmarks_packet = status_or_packets.value()[kHandWorldLandmarksStreamName]; - Packet image_packet = status_or_packets.value()[kImageOutStreamName]; result_callback( {{gesture_packet.Get>(), handedness_packet.Get>(), @@ -218,6 +226,9 @@ absl::StatusOr GestureRecognizer::Recognize( ASSIGN_OR_RETURN(auto output_packets, ProcessImageData({{kImageInStreamName, MakePacket(std::move(image))}})); + if (output_packets[kHandGesturesStreamName].IsEmpty()) { + return {{{}, {}, {}, {}}}; + } return { {/* gestures= */ {output_packets[kHandGesturesStreamName] .Get>()}, @@ -247,6 +258,9 @@ absl::StatusOr GestureRecognizer::RecognizeForVideo( {{kImageInStreamName, MakePacket(std::move(image)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kHandGesturesStreamName].IsEmpty()) { + return {{{}, {}, {}, {}}}; + } return { {/* gestures= */ {output_packets[kHandGesturesStreamName] .Get>()}, From 77de8b0bb058845af4c28a716fea8245f89d8fb8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 00:31:05 -0700 Subject: [PATCH 07/18] Split RotationMode proto target for reuse. PiperOrigin-RevId: 480548644 --- mediapipe/calculators/image/BUILD | 8 ++++++ .../image/image_transformation_calculator.cc | 1 + .../image_transformation_calculator.proto | 12 +------- .../calculators/image/rotation_mode.proto | 28 +++++++++++++++++++ 4 files changed, 38 insertions(+), 11 deletions(-) create mode 100644 mediapipe/calculators/image/rotation_mode.proto diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 458c5368b..89e2d371c 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -209,11 +209,18 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "rotation_mode_proto", + srcs = ["rotation_mode.proto"], + visibility = ["//visibility:public"], +) + mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/gpu:scale_mode_proto", @@ -238,6 +245,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index bc7fd8df7..84697cc62 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "mediapipe/calculators/image/image_transformation_calculator.pb.h" +#include "mediapipe/calculators/image/rotation_mode.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" diff --git a/mediapipe/calculators/image/image_transformation_calculator.proto b/mediapipe/calculators/image/image_transformation_calculator.proto index c90e03be9..739c5bfbb 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.proto +++ b/mediapipe/calculators/image/image_transformation_calculator.proto @@ -16,20 +16,10 @@ syntax = "proto2"; package mediapipe; +import "mediapipe/calculators/image/rotation_mode.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/gpu/scale_mode.proto"; -// Counterclockwise rotation. -message RotationMode { - enum Mode { - UNKNOWN = 0; - ROTATION_0 = 1; - ROTATION_90 = 2; - ROTATION_180 = 3; - ROTATION_270 = 4; - } -} - message ImageTransformationCalculatorOptions { extend CalculatorOptions { optional ImageTransformationCalculatorOptions ext = 251952830; diff --git a/mediapipe/calculators/image/rotation_mode.proto b/mediapipe/calculators/image/rotation_mode.proto new file mode 100644 index 000000000..d4859aa4c --- /dev/null +++ b/mediapipe/calculators/image/rotation_mode.proto @@ -0,0 +1,28 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +// Counterclockwise rotation. +message RotationMode { + enum Mode { + UNKNOWN = 0; + ROTATION_0 = 1; + ROTATION_90 = 2; + ROTATION_180 = 3; + ROTATION_270 = 4; + } +} From f9a4e472eba2755e7ba4b2012f125e643b82ddaa Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 12 Oct 2022 02:03:26 -0700 Subject: [PATCH 08/18] Add AndroidManifest.xml into third_party/mediapipe/tasks/java/com/google/mediapipe/tasks/core PiperOrigin-RevId: 480567195 --- .../com/google/mediapipe/tasks/core/AndroidManifest.xml | 8 ++++++++ .../tasks/java/com/google/mediapipe/tasks/core/BUILD | 1 + 2 files changed, 9 insertions(+) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/AndroidManifest.xml diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/AndroidManifest.xml new file mode 100644 index 000000000..e45fc3dcf --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 8da7b8561..b4ebfe8cc 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -20,6 +20,7 @@ android_library( javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", ], + manifest = "AndroidManifest.xml", deps = [ "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", From cbbd4718a01ee014fe0fc77945738cf35dd6b738 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 12 Oct 2022 03:44:52 -0700 Subject: [PATCH 09/18] Update mediapipe_aar.bzl to put more mediapipe framework java proto classes into AARs. PiperOrigin-RevId: 480583365 --- .../com/google/mediapipe/mediapipe_aar.bzl | 47 +++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 7f2cb146c..9b01e2f0b 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -164,7 +164,10 @@ EOF assets_dir = assets_dir, ) - _aar_with_jni(name, name + "_android_lib") + mediapipe_build_aar_with_jni( + name = name, + android_library = name + "_android_lib", + ) def _mediapipe_jni(name, gen_libmediapipe, calculators = []): """Generates MediaPipe jni library. @@ -203,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []): alwayslink = 1, ) -def _aar_with_jni(name, android_library): +def mediapipe_build_aar_with_jni(name, android_library): + """Builds MediaPipe AAR with jni. + + Args: + name: The bazel target name. + android_library: the android library that contains jni. + """ + # Generates dummy AndroidManifest.xml for dummy apk usage # (dummy apk is generated by _dummy_app target below) native.genrule( @@ -214,7 +224,7 @@ cat > $(OUTS) < - + EOF """, @@ -241,6 +251,7 @@ chmod +w $(location :{}.aar) origdir=$$PWD cd $$(mktemp -d) unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*" +find lib -name *_dummy_app.so -delete cp -r lib jni zip -r $$origdir/$(location :{}.aar) jni/*/*.so """.format(android_library, name, name, name, name), @@ -287,6 +298,36 @@ def mediapipe_java_proto_srcs(name = ""): src_out = "com/google/mediapipe/proto/CalculatorProto.java", )) + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:calculator_options_java_proto_lite", + src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:stream_handler_java_proto_lite", + src_out = "com/google/mediapipe/proto/StreamHandlerProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:packet_factory_java_proto_lite", + src_out = "com/google/mediapipe/proto/PacketFactoryProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:packet_generator_java_proto_lite", + src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:status_handler_java_proto_lite", + src_out = "com/google/mediapipe/proto/StatusHandlerProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:mediapipe_options_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java", + )) + proto_src_list.append(mediapipe_java_proto_src_extractor( target = "//mediapipe/framework/formats:landmark_java_proto_lite", src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", From 51a760608380f895570304b6ad4014ce01a94046 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 10:19:45 -0700 Subject: [PATCH 10/18] Add Java ImageClassifier API. PiperOrigin-RevId: 480656683 --- .../imageclassifier/AndroidManifest.xml | 8 + .../tasks/vision/imageclassifier/BUILD | 46 ++ .../ImageClassificationResult.java | 102 ++++ .../imageclassifier/ImageClassifier.java | 456 ++++++++++++++++++ .../imageclassifier/AndroidManifest.xml | 24 + .../tasks/vision/imageclassifier/BUILD | 19 + .../imageclassifier/ImageClassifierTest.java | 445 +++++++++++++++++ 7 files changed, 1100 insertions(+) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml new file mode 100644 index 000000000..e257ddc42 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD new file mode 100644 index 000000000..cecd9f521 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD @@ -0,0 +1,46 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "imageclassifier", + srcs = [ + "ImageClassificationResult.java", + "ImageClassifier.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = ":AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java new file mode 100644 index 000000000..09f854caa --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java @@ -0,0 +1,102 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.container.proto.CategoryProto; +import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.ClassificationEntry; +import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Represents the classification results generated by {@link ImageClassifier}. */ +@AutoValue +public abstract class ImageClassificationResult implements TaskResult { + + /** + * Creates an {@link ImageClassificationResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf + * message. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static ImageClassificationResult create( + ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { + List classifications = new ArrayList<>(); + for (ClassificationsProto.Classifications classificationsProto : + classificationResult.getClassificationsList()) { + classifications.add(classificationsFromProto(classificationsProto)); + } + return new AutoValue_ImageClassificationResult( + timestampMs, Collections.unmodifiableList(classifications)); + } + + @Override + public abstract long timestampMs(); + + /** Contains one set of results per classifier head. */ + public abstract List classifications(); + + /** + * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. + * + * @param category the {@link CategoryProto.Category} protobuf message to convert. + */ + static Category categoryFromProto(CategoryProto.Category category) { + return Category.create( + category.getScore(), + category.getIndex(), + category.getCategoryName(), + category.getDisplayName()); + } + + /** + * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link + * ClassificationEntry} object. + * + * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. + */ + static ClassificationEntry classificationEntryFromProto( + ClassificationsProto.ClassificationEntry entry) { + List categories = new ArrayList<>(); + for (CategoryProto.Category category : entry.getCategoriesList()) { + categories.add(categoryFromProto(category)); + } + return ClassificationEntry.create(categories, entry.getTimestampMs()); + } + + /** + * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link + * Classifications} object. + * + * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to + * convert. + */ + static Classifications classificationsFromProto( + ClassificationsProto.Classifications classifications) { + List entries = new ArrayList<>(); + for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { + entries.add(classificationEntryFromProto(entry)); + } + return Classifications.create( + entries, classifications.getHeadIndex(), classifications.getHeadName()); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java new file mode 100644 index 000000000..68cae151f --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -0,0 +1,456 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import android.content.Context; +import android.graphics.RectF; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto; +import com.google.protobuf.InvalidProtocolBufferException; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs classification on images. + * + *

The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata.. + * + *

The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + *

    + *
  • Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + *
      + *
    • image input of size {@code [batch x height x width x channels]}. + *
    • batch inference is not supported ({@code batch} is required to be 1). + *
    • only RGB inputs are supported ({@code channels} is required to be 3). + *
    • if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the + * metadata for input normalization. + *
    + *
  • At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with: + *
      + *
    • {@code N} classes and either 2 or 4 dimensions, i.e. {@code [1 x N]} or {@code [1 x 1 + * x 1 x N]} + *
    • optional (but recommended) label map(s) as AssociatedFile-s with type + * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if + * any) is used to fill the {@code class_name} field of the results. The {@code + * display_name} field is filled from the AssociatedFile (if any) whose locale matches + * the {@code display_names_locale} field of the {@code ImageClassifierOptions} used at + * creation time ("en" by default, i.e. English). If none of these are available, only + * the {@code index} field of the results will be filled. + *
    • optional score calibration can be attached using ScoreCalibrationOptions and an + * AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See + * metadata_schema.fbs for more details. + *
    + *
+ * + *

An example of such model can be found + * TensorFlow Hub. + */ +public final class ImageClassifier extends BaseVisionTaskApi { + private static final String TAG = ImageClassifier.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out")); + private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; + + static { + ProtoUtil.registerTypeName( + ClassificationsProto.ClassificationResult.class, + "mediapipe.tasks.components.containers.proto.ClassificationResult"); + } + + /** + * Creates an {@link ImageClassifier} instance from a model file and default {@link + * ImageClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the classification model in the assets. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageClassifier} instance from a model file and default {@link + * ImageClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the classification model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ImageClassifier} instance from a model buffer and default {@link + * ImageClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageClassifier} instance from an {@link ImageClassifierOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link ImageClassifierOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageClassificationResult convertToTaskResult(List packets) { + try { + return ImageClassificationResult.create( + PacketGetter.getProto( + packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance()), + packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + } catch (InvalidProtocolBufferException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Image convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + options.resultListener().ifPresent(handler::setResultListener); + options.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageClassifier(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageClassifier} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageClassifier(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs classification on the provided single image. Only use this method when the {@link + * ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(Image inputImage) { + return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF()); + } + + /** + * Performs classification on the provided single image and region-of-interest. Only use this + * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} specifying the region of interest on which to perform + * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(Image inputImage, RectF roi) { + return (ImageClassificationResult) processImageData(inputImage, roi); + } + + /** + * Performs classification on the provided video frame. Only use this method when the {@link + * ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) { + return (ImageClassificationResult) + processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + } + + /** + * Performs classification on the provided video frame with additional region-of-interest. Only + * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} specifying the region of interest on which to perform + * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classifyForVideo( + Image inputImage, RectF roi, long inputTimestampMs) { + return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs); + } + + /** + * Sends live image data to perform classification, and the results will be available via the + * {@link ResultListener} provided in the {@link ImageClassifierOptions}. Only use this method + * when the {@link ImageClassifier} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void classifyAsync(Image inputImage, long inputTimestampMs) { + sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); + } + + /** + * Sends live image data and additional region-of-interest to perform classification, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} specifying the region of interest on which to perform + * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) { + sendLiveStreamData(inputImage, roi, inputTimestampMs); + } + + /** Options for setting up and {@link ImageClassifier}. */ + @AutoValue + public abstract static class ImageClassifierOptions extends TaskOptions { + + /** Builder for {@link ImageClassifierOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the image classifier task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the image classifier task. Default to the image mode. + * Image classifier has three modes: + * + *
    + *
  • IMAGE: The mode for performing classification on single image inputs. + *
  • VIDEO: The mode for performing classification on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for for performing classification on a live stream of input + * data, such as from camera. In this mode, {@code setResultListener} must be called to + * set up a listener to receive the classification results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets the optional {@link ClassifierOptions} controling classification behavior, such as + * score threshold, number of results, etc. + */ + public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + + /** + * Sets the {@link ResultListener} to receive the classification results asynchronously when + * the image classifier is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract ImageClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link ImageClassifierOptions} instance. * + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image classifier + * is in the live stream mode. + */ + public final ImageClassifierOptions build() { + ImageClassifierOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image classifier is in the live stream mode, a user-defined result listener" + + " must be provided in the ImageClassifierOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image classifier is in the image or video mode, a user-defined result listener" + + " shouldn't be provided in ImageClassifierOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional classifierOptions(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() + .setRunningMode(RunningMode.IMAGE); + } + + /** + * Converts a {@link ImageClassifierOptions} to a {@link CalculatorOptions} protobuf message. + */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = + ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (classifierOptions().isPresent()) { + taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** Creates a RectF covering the full image. */ + private static RectF buildFullImageRectF() { + return new RectF(0, 0, 1, 1); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml new file mode 100644 index 000000000..66fa20509 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java new file mode 100644 index 000000000..e02e8ebe7 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -0,0 +1,445 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageClassifier}/ */ +@RunWith(Suite.class) +@SuiteClasses({ImageClassifierTest.General.class, ImageClassifierTest.RunningModeTest.class}) +public class ImageClassifierTest { + private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite"; + private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg"; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageClassifierTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ImageClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. + ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifier.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void classify_succeedsWithNoOptions() throws Exception { + ImageClassifier imageClassifier = + ImageClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001); + assertThat(results.classifications().get(0).entries().get(0).categories().get(0)) + .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", "")); + } + + @Test + public void classify_succeedsWithFloatModel() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.027329788f, 932, "bagel", ""), + Category.create(0.019334773f, 925, "guacamole", ""))); + } + + @Test + public void classify_succeedsWithQuantizedModel() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", ""))); + } + + @Test + public void classify_succeedsWithScoreThreshold() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.027329788f, 932, "bagel", ""))); + } + + @Test + public void classify_succeedsWithAllowlist() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions( + ClassifierOptions.builder() + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) + .build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.019334773f, 925, "guacamole", ""), + Category.create(0.006279315f, 963, "meat loaf", ""))); + } + + @Test + public void classify_succeedsWithDenylist() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions( + ClassifierOptions.builder() + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) + .build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.019334773f, 925, "guacamole", ""), + Category.create(0.006279315f, 963, "meat loaf", ""))); + } + + @Test + public void classify_succeedsWithRegionOfInterest() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // RectF around the soccer ball. + RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); + ImageClassificationResult results = + imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageClassifierTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(mode) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void classify_failsWithCallingWrongApiInImageMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void classify_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void classify_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void classify_succeedsWithImageMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + } + + @Test + public void classify_succeedsWithVideoMode() throws Exception { + Image image = getImageFromAsset(BURGER_IMAGE); + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + ImageClassificationResult results = imageClassifier.classifyForVideo(image, i); + assertHasOneHeadAndOneTimestamp(results, i); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + } + } + + @Test + public void classify_failsWithOutOfOrderInputTimestamps() throws Exception { + Image image = getImageFromAsset(BURGER_IMAGE); + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageClassificationResult, inputImage) -> { + assertCategoriesAre( + imageClassificationResult, + Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1); + MediaPipeException exception = + assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void classify_succeedsWithLiveStreamMode() throws Exception { + Image image = getImageFromAsset(BURGER_IMAGE); + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageClassificationResult, inputImage) -> { + assertCategoriesAre( + imageClassificationResult, + Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageClassifier.classifyAsync(image, i); + } + } + } + } + + private static Image getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndOneTimestamp( + ImageClassificationResult results, long timestampMs) { + assertThat(results.classifications()).hasSize(1); + assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); + assertThat(results.classifications().get(0).entries()).hasSize(1); + assertThat(results.classifications().get(0).entries().get(0).timestampMs()) + .isEqualTo(timestampMs); + } + + private static void assertCategoriesAre( + ImageClassificationResult results, List categories) { + assertThat(results.classifications().get(0).entries().get(0).categories()) + .hasSize(categories.size()); + for (int i = 0; i < categories.size(); i++) { + assertThat(results.classifications().get(0).entries().get(0).categories().get(i)) + .isEqualTo(categories.get(i)); + } + } + + private static void assertImageSizeIsExpected(Image inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +} From ae4b2ae577e47fe503f252cc79041819da885224 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 11:33:11 -0700 Subject: [PATCH 11/18] Add support for input image rotation in ImageClassifier. PiperOrigin-RevId: 480676070 --- .../image_classifier/image_classifier.cc | 45 +++--- .../image_classifier/image_classifier.h | 55 ++++++-- .../image_classifier/image_classifier_test.cc | 131 +++++++++++++++--- mediapipe/tasks/testdata/vision/BUILD | 4 + third_party/external_files.bzl | 12 ++ 5 files changed, 194 insertions(+), 53 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 0338b2ee2..f3dcdd07d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -59,14 +59,24 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; -// Builds a NormalizedRect covering the entire image. -NormalizedRect BuildFullImageNormRect() { - NormalizedRect norm_rect; - norm_rect.set_x_center(0.5); - norm_rect.set_y_center(0.5); - norm_rect.set_width(1); - norm_rect.set_height(1); - return norm_rect; +// Returns a NormalizedRect covering the full image if input is not present. +// Otherwise, makes sure the x_center, y_center, width and height are set in +// case only a rotation was provided in the input. +NormalizedRect FillNormalizedRect( + std::optional normalized_rect) { + NormalizedRect result; + if (normalized_rect.has_value()) { + result = *normalized_rect; + } + bool has_coordinates = result.has_x_center() || result.has_y_center() || + result.has_width() || result.has_height(); + if (!has_coordinates) { + result.set_x_center(0.5); + result.set_y_center(0.5); + result.set_width(1); + result.set_height(1); + } + return result; } // Creates a MediaPipe graph config that contains a subgraph node of @@ -154,15 +164,14 @@ absl::StatusOr> ImageClassifier::Create( } absl::StatusOr ImageClassifier::Classify( - Image image, std::optional roi) { + Image image, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -173,15 +182,15 @@ absl::StatusOr ImageClassifier::Classify( } absl::StatusOr ImageClassifier::ClassifyForVideo( - Image image, int64 timestamp_ms, std::optional roi) { + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -195,16 +204,16 @@ absl::StatusOr ImageClassifier::ClassifyForVideo( .Get(); } -absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, - std::optional roi) { +absl::Status ImageClassifier::ClassifyAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 24f36017a..5dff06cc7 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -105,9 +105,18 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Performs image classification on the provided single image. Classification - // is performed on the region of interest specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs image classification on the provided single image. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° + // anti-clockwise rotation). + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is + // set, they will automatically be set to cover the full image. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the image // running mode. @@ -117,11 +126,21 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // YUVToImageCalculator is integrated. absl::StatusOr Classify( mediapipe::Image image, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); - // Performs image classification on the provided video frame. Classification - // is performed on the region of interested specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs image classification on the provided video frame. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° + // anti-clockwise rotation). + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is + // set, they will automatically be set to cover the full image. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the video // running mode. @@ -131,12 +150,22 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + std::optional + image_processing_options = std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. - // Classification is performed on the region of interested specified by the - // `roi` argument if provided, or on the entire image otherwise. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° + // anti-clockwise rotation). + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is + // set, they will automatically be set to cover the full image. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the live // stream running mode. @@ -153,9 +182,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status ClassifyAsync( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // TODO: add Classify() variants taking a region of interest as // additional argument. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index dcb2fddfc..55830e520 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" +#include #include #include #include @@ -546,18 +547,102 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.532); + image_processing_options.set_y_center(0.521); + image_processing_options.set_width(0.164); + image_processing_options.set_height(0.427); - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi)); + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + // Specify a 90° anti-clockwise rotation. + NormalizedRect image_processing_options; + image_processing_options.set_rotation(M_PI / 2.0); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); + + // Results differ slightly from the non-rotated image, but that's expected + // as models are very sensitive to the slightest numerical differences + // introduced by the rotation and JPG encoding. + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.6371766 + category_name: "cheeseburger" + } + categories { + index: 963 + score: 0.049443405 + category_name: "meat loaf" + } + categories { + index: 925 + score: 0.047918003 + category_name: "guacamole" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + // Crop around the chair, with 90° anti-clockwise rotation. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.2821); + image_processing_options.set_y_center(0.2406); + image_processing_options.set_width(0.5642); + image_processing_options.set_height(0.1286); + image_processing_options.set_rotation(M_PI / 2.0); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); + + ExpectApproximatelyEqual(results, + ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 560 + score: 0.6800408 + category_name: "folding chair" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { @@ -646,16 +731,17 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.532); + image_processing_options.set_y_center(0.521); + image_processing_options.set_width(0.164); + image_processing_options.set_height(0.427); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto results, - image_classifier->ClassifyForVideo(image, i, roi)); + MP_ASSERT_OK_AND_ASSIGN( + auto results, + image_classifier->ClassifyForVideo(image, i, image_processing_options)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); } MP_ASSERT_OK(image_classifier->Close()); @@ -790,15 +876,16 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.532); + image_processing_options.set_y_center(0.521); + image_processing_options.set_width(0.164); + image_processing_options.set_height(0.427); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi)); + MP_ASSERT_OK( + image_classifier->ClassifyAsync(image, i, image_processing_options)); } MP_ASSERT_OK(image_classifier->Close()); diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 290b29016..8b205cc49 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -25,6 +25,7 @@ package( mediapipe_files(srcs = [ "burger.jpg", "burger_crop.jpg", + "burger_rotated.jpg", "cat.jpg", "cat_mask.jpg", "cats_and_dogs.jpg", @@ -46,6 +47,7 @@ mediapipe_files(srcs = [ "mobilenet_v3_small_100_224_embedder.tflite", "mozart_square.jpg", "multi_objects.jpg", + "multi_objects_rotated.jpg", "palm_detection_full.tflite", "pointing_up.jpg", "right_hands.jpg", @@ -72,6 +74,7 @@ filegroup( srcs = [ "burger.jpg", "burger_crop.jpg", + "burger_rotated.jpg", "cat.jpg", "cat_mask.jpg", "cats_and_dogs.jpg", @@ -81,6 +84,7 @@ filegroup( "left_hands.jpg", "mozart_square.jpg", "multi_objects.jpg", + "multi_objects_rotated.jpg", "pointing_up.jpg", "right_hands.jpg", "segmentation_golden_rotation0.png", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 254692856..24fb15446 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -58,6 +58,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/burger.jpg?generation=1661875667922678"], ) + http_file( + name = "com_google_mediapipe_burger_rotated_jpg", + sha256 = "b7bb5e59ef778f3ce6b3e616c511908a53d513b83a56aae58b7453e14b0a4b2a", + urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"], + ) + http_file( name = "com_google_mediapipe_cat_jpg", sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be", @@ -436,6 +442,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"], ) + http_file( + name = "com_google_mediapipe_multi_objects_rotated_jpg", + sha256 = "175f6c572ffbab6554e382fd5056d09720eef931ccc4ed79481bdc47a8443911", + urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects_rotated.jpg?generation=1665065847969523"], + ) + http_file( name = "com_google_mediapipe_object_detection_3d_camera_tflite", sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d", From a9fea36cb3c4ce28e2a1f217d3d098a0152c8bb6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 11:41:28 -0700 Subject: [PATCH 12/18] Put Destination::base_ into private section. (Cleanup.) PiperOrigin-RevId: 480678168 --- mediapipe/framework/api2/builder.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 11bcd21c6..111451ce9 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -120,6 +120,9 @@ using AllowCast = std::integral_constant && } // namespace internal_builder +template +class SourceImpl; + // These classes wrap references to the underlying source/destination // endpoints, adding type information and the user-visible API. template @@ -137,10 +140,14 @@ class DestinationImpl { return DestinationImpl(&base_); } + private: DestinationBase& base_; + + template + friend class SourceImpl; }; -template +template class SourceImpl { public: using Base = SourceBase; From 02746d0700ff3a99ba1ebf2a3fdbbb3edde0ac35 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 14:36:15 -0700 Subject: [PATCH 13/18] Remove unused includes. PiperOrigin-RevId: 480720274 --- mediapipe/framework/api2/BUILD | 5 ----- mediapipe/framework/api2/builder.h | 5 ----- 2 files changed, 10 deletions(-) diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index 6de444438..76aace6f5 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -14,15 +14,10 @@ cc_library( name = "builder", hdrs = ["builder.h"], deps = [ - ":const_str", - ":contract", - ":node", - ":packet", ":port", "//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_contract", "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 111451ce9..90c43a8d5 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -5,12 +5,7 @@ #include #include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" -#include "mediapipe/framework/api2/const_str.h" -#include "mediapipe/framework/api2/contract.h" -#include "mediapipe/framework/api2/node.h" -#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_contract.h" From 179824a21de1c227a7d009bf9a86030cd603d0e1 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 14:40:26 -0700 Subject: [PATCH 14/18] Use string_view when adding nodes/generators #cleanup PiperOrigin-RevId: 480721234 --- mediapipe/framework/api2/builder.h | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 90c43a8d5..82905d2f5 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -440,8 +440,9 @@ class Graph { // Creates a node of a specific type. Should be used for pure interfaces, // which do not have a built-in type string. template - Node& AddNode(const std::string& type) { - auto node = std::make_unique>(type); + Node& AddNode(absl::string_view type) { + auto node = + std::make_unique>(std::string(type.data(), type.size())); auto node_p = node.get(); nodes_.emplace_back(std::move(node)); return *node_p; @@ -449,16 +450,18 @@ class Graph { // Creates a generic node, with no compile-time checking of inputs and // outputs. This can be used for calculators whose contract is not visible. - GenericNode& AddNode(const std::string& type) { - auto node = std::make_unique(type); + GenericNode& AddNode(absl::string_view type) { + auto node = + std::make_unique(std::string(type.data(), type.size())); auto node_p = node.get(); nodes_.emplace_back(std::move(node)); return *node_p; } // For legacy PacketGenerators. - PacketGenerator& AddPacketGenerator(const std::string& type) { - auto node = std::make_unique(type); + PacketGenerator& AddPacketGenerator(absl::string_view type) { + auto node = std::make_unique( + std::string(type.data(), type.size())); auto node_p = node.get(); packet_gens_.emplace_back(std::move(node)); return *node_p; From 12c323ffdecc02c9c8617ee72787cea1de183a36 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 15:39:55 -0700 Subject: [PATCH 15/18] Exports gesture recognizer related proto as Java package. PiperOrigin-RevId: 480735444 --- .../containers/proto/landmarks_detection_result.proto | 3 +++ .../proto/gesture_classifier_graph_options.proto | 3 +++ .../proto/gesture_embedder_graph_options.proto | 3 +++ .../proto/hand_gesture_recognizer_graph_options.proto | 3 +++ .../hand_landmarker/proto/hand_landmarker_graph_options.proto | 3 +++ .../proto/hand_landmarks_detector_graph_options.proto | 3 +++ 6 files changed, 18 insertions(+) diff --git a/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto b/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto index 9be6ce47a..ac44f9b58 100644 --- a/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto +++ b/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/formats/classification.proto"; import "mediapipe/framework/formats/landmark.proto"; import "mediapipe/framework/formats/rect.proto"; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "LandmarksDetectionResultProto"; + message LandmarksDetectionResult { optional mediapipe.NormalizedLandmarkList landmarks = 1; optional mediapipe.ClassificationList classifications = 2; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index 7730f005f..61e367b2b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer"; +option java_outer_classname = "GestureClassifierGraphOptionsProto"; + message GestureClassifierGraphOptions { extend mediapipe.CalculatorOptions { optional GestureClassifierGraphOptions ext = 478825465; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index c12359eb3..6170006f4 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -20,6 +20,9 @@ package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer"; +option java_outer_classname = "GestureEmbedderGraphOptionsProto"; + message GestureEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional GestureEmbedderGraphOptions ext = 478825422; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index f71a6b22f..982841690 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -23,6 +23,9 @@ import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer"; +option java_outer_classname = "HandGestureRecognizerGraphOptionsProto"; + message HandGestureRecognizerGraphOptions { extend mediapipe.CalculatorOptions { optional HandGestureRecognizerGraphOptions ext = 463370452; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 7f3536b09..cd4efc042 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -22,6 +22,9 @@ import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.handlandmarker"; +option java_outer_classname = "HandLandmarkerGraphOptionsProto"; + message HandLandmarkerGraphOptions { extend mediapipe.CalculatorOptions { optional HandLandmarkerGraphOptions ext = 462713202; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 8c0fc66f2..7b8b45c4f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -20,6 +20,9 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.handlandmarker"; +option java_outer_classname = "HandLandmarksDetectorGraphOptionsProto"; + message HandLandmarksDetectorGraphOptions { extend mediapipe.CalculatorOptions { optional HandLandmarksDetectorGraphOptions ext = 474472470; From 9353ed6cced1668cd0841f9ba9542819c5d5e093 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 13 Oct 2022 14:00:35 -0700 Subject: [PATCH 16/18] Java gesture recognizer Tasks API and unit test. PiperOrigin-RevId: 480978244 --- .../gesture_recognizer/gesture_recognizer.h | 5 +- .../google/mediapipe/tasks/vision/core/BUILD | 1 + .../tasks/vision/gesturerecognizer/BUILD | 9 + .../gesturerecognizer/GestureRecognizer.java | 466 +++++++++++++++++ .../gesturerecognizer/AndroidManifest.xml | 24 + .../tasks/vision/gesturerecognizer/BUILD | 19 + .../GestureRecognizerTest.java | 495 ++++++++++++++++++ 7 files changed, 1016 insertions(+), 3 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h index 17c9cc921..53b824e25 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -87,9 +87,8 @@ struct GestureRecognizerOptions { // Performs hand gesture recognition on the given image. // // TODO add the link to DevSite. -// This API expects expects a pre-trained hand gesture model asset bundle, or a -// custom one created using Model Maker. See . +// This API expects a pre-trained hand gesture model asset bundle, or a custom +// one created using Model Maker. See . // // Inputs: // Image diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD index 8df9173b2..453ae9a90 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD @@ -40,6 +40,7 @@ cc_binary( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD index eb3eca52b..7782a747e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD @@ -20,6 +20,7 @@ android_library( name = "gesturerecognizer", srcs = [ "GestureRecognitionResult.java", + "GestureRecognizer.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -29,11 +30,19 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework/formats:classification_java_proto_lite", "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java new file mode 100644 index 000000000..e429cc6dc --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -0,0 +1,466 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.gesturerecognizer; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.handdetector.HandDetectorGraphOptionsProto; +import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarkerGraphOptionsProto; +import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarksDetectorGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs gesture recognition on images. + * + *

This API expects a pre-trained hand gesture model asset bundle, or a custom one created using + * Model Maker. See . + * + *

    + *
  • Input image {@link Image} + *
      + *
    • The image that gesture recognition runs on. + *
    + *
  • Output GestureRecognitionResult {@link GestureRecognitionResult} + *
      + *
    • A GestureRecognitionResult containing hand landmarks and recognized hand gestures. + *
    + *
+ */ +public final class GestureRecognizer extends BaseVisionTaskApi { + private static final String TAG = GestureRecognizer.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "LANDMARKS:hand_landmarks", + "WORLD_LANDMARKS:world_hand_landmarks", + "HANDEDNESS:handedness", + "HAND_GESTURES:hand_gestures", + "IMAGE:image_out")); + private static final int LANDMARKS_OUT_STREAM_INDEX = 0; + private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1; + private static final int HANDEDNESS_OUT_STREAM_INDEX = 2; + private static final int HAND_GESTURES_OUT_STREAM_INDEX = 3; + private static final int IMAGE_OUT_STREAM_INDEX = 4; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; + + /** + * Creates a {@link GestureRecognizer} instance from a model file and the default {@link + * GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the gesture recognition model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link GestureRecognizer} instance from a model file and the default {@link + * GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the gesture recognition model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromFile(Context context, File modelFile) + throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link GestureRecognizer} instance from a model buffer and the default {@link + * GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link GestureRecognizer} instance from a {@link GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param recognizerOptions a {@link GestureRecognizerOptions} instance. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromOptions( + Context context, GestureRecognizerOptions recognizerOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public GestureRecognitionResult convertToTaskResult(List packets) { + // If there is no hands detected in the image, just returns empty lists. + if (packets.get(HAND_GESTURES_OUT_STREAM_INDEX).isEmpty()) { + return GestureRecognitionResult.create( + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp()); + } + return GestureRecognitionResult.create( + PacketGetter.getProtoVector( + packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()), + PacketGetter.getProtoVector( + packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()), + PacketGetter.getProtoVector( + packets.get(HANDEDNESS_OUT_STREAM_INDEX), ClassificationList.parser()), + PacketGetter.getProtoVector( + packets.get(HAND_GESTURES_OUT_STREAM_INDEX), ClassificationList.parser()), + packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp()); + } + + @Override + public Image convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + recognizerOptions.resultListener().ifPresent(handler::setResultListener); + recognizerOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(recognizerOptions) + .setEnableFlowLimiting(recognizerOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new GestureRecognizer(runner, recognizerOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link GestureRecognizer} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private GestureRecognizer(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); + } + + /** + * Performs gesture recognition on the provided single image. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognize(Image inputImage) { + return (GestureRecognitionResult) processImageData(inputImage); + } + + /** + * Performs gesture recognition on the provided video frame. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) { + return (GestureRecognitionResult) processVideoData(inputImage, inputTimestampMs); + } + + /** + * Sends live image data to perform gesture recognition, and the results will be available via the + * {@link ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method + * when the {@link GestureRecognition} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the gesture recognizer. The input timestamps must be monotonically increasing. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void recognizeAsync(Image inputImage, long inputTimestampMs) { + sendLiveStreamData(inputImage, inputTimestampMs); + } + + /** Options for setting up an {@link GestureRecognizer}. */ + @AutoValue + public abstract static class GestureRecognizerOptions extends TaskOptions { + + /** Builder for {@link GestureRecognizerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the gesture recognizer task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the gesture recognizer task. Default to the image mode. Gesture + * recognizer has three modes: + * + *
    + *
  • IMAGE: The mode for recognizing gestures on single image inputs. + *
  • VIDEO: The mode for recognizing gestures on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for for recognizing gestures on a live stream of input data, + * such as from camera. In this mode, {@code setResultListener} must be called to set up + * a listener to receive the recognition results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode value); + + // TODO: remove these. Temporary solutions before bundle asset is ready. + public abstract Builder setBaseOptionsHandDetector(BaseOptions value); + + public abstract Builder setBaseOptionsHandLandmarker(BaseOptions value); + + public abstract Builder setBaseOptionsGestureRecognizer(BaseOptions value); + + /** Sets the maximum number of hands can be detected by the GestureRecognizer. */ + public abstract Builder setNumHands(Integer value); + + /** Sets minimum confidence score for the hand detection to be considered successfully */ + public abstract Builder setMinHandDetectionConfidence(Float value); + + /** Sets minimum confidence score of hand presence score in the hand landmark detection. */ + public abstract Builder setMinHandPresenceConfidence(Float value); + + /** Sets the minimum confidence score for the hand tracking to be considered successfully. */ + public abstract Builder setMinTrackingConfidence(Float value); + + /** + * Sets the minimum confidence score for the gestures to be considered successfully. If < 0, + * the gesture confidence threshold=0.5 for the model is used. + * + *

TODO Note this option is subject to change, after scoring merging + * calculator is implemented. + */ + public abstract Builder setMinGestureConfidence(Float value); + + /** + * Sets the result listener to receive the detection results asynchronously when the gesture + * recognizer is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional error listener. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract GestureRecognizerOptions autoBuild(); + + /** + * Validates and builds the {@link GestureRecognizerOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the object detector is + * in the live stream mode. + */ + public final GestureRecognizerOptions build() { + GestureRecognizerOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The gesture recognizer is in the live stream mode, a user-defined result listener" + + " must be provided in GestureRecognizerOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The gesture recognizer is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in GestureRecognizerOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + // TODO: remove these. Temporary solutions before bundle asset is ready. + abstract BaseOptions baseOptionsHandDetector(); + + abstract BaseOptions baseOptionsHandLandmarker(); + + abstract BaseOptions baseOptionsGestureRecognizer(); + + abstract RunningMode runningMode(); + + abstract Optional numHands(); + + abstract Optional minHandDetectionConfidence(); + + abstract Optional minHandPresenceConfidence(); + + abstract Optional minTrackingConfidence(); + + // TODO update gesture confidence options after score merging calculator is ready. + abstract Optional minGestureConfidence(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_GestureRecognizer_GestureRecognizerOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setNumHands(1) + .setMinHandDetectionConfidence(0.5f) + .setMinHandPresenceConfidence(0.5f) + .setMinTrackingConfidence(0.5f) + .setMinGestureConfidence(-1f); + } + + /** + * Converts a {@link GestureRecognizerOptions} to a {@link CalculatorOptions} protobuf message. + */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())); + GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder = + GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + + // Setup HandDetectorGraphOptions. + HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder + handDetectorGraphOptionsBuilder = + HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsHandDetector()))); + numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands); + minHandDetectionConfidence() + .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence); + + // Setup HandLandmarkerGraphOptions. + HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder + handLandmarksDetectorGraphOptionsBuilder = + HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + minHandPresenceConfidence() + .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); + HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder + handLandmarkerGraphOptionsBuilder = + HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + minTrackingConfidence() + .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence); + handLandmarkerGraphOptionsBuilder + .setHandDetectorGraphOptions(handDetectorGraphOptionsBuilder.build()) + .setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptionsBuilder.build()); + + // Setup HandGestureRecognizerGraphOptions. + HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder + handGestureRecognizerGraphOptionsBuilder = + HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsGestureRecognizer()))); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold); + handGestureRecognizerGraphOptionsBuilder.setClassifierOptions( + classifierOptionsBuilder.build()); + + taskOptionsBuilder + .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build()) + .setHandGestureRecognizerGraphOptions(handGestureRecognizerGraphOptionsBuilder.build()); + return CalculatorOptions.newBuilder() + .setExtension( + GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml new file mode 100644 index 000000000..dd3ceb848 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java new file mode 100644 index 000000000..efec02b2a --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -0,0 +1,495 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.gesturerecognizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.truth.Correspondence; +import com.google.mediapipe.formats.proto.ClassificationProto; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions; +import java.io.InputStream; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link GestureRecognizer}. */ +@RunWith(Suite.class) +@SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class}) +public class GestureRecognizerTest { + private static final String HAND_DETECTOR_MODEL_FILE = "palm_detection_full.tflite"; + private static final String HAND_LANDMARKER_MODEL_FILE = "hand_landmark_full.tflite"; + private static final String GESTURE_RECOGNIZER_MODEL_FILE = + "cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite"; + private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; + private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; + private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; + private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb"; + private static final String TAG = "Gesture Recognizer Test"; + private static final String THUMB_UP_LABEL = "Thumb_Up"; + private static final int THUMB_UP_INDEX = 5; + private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final int IMAGE_WIDTH = 382; + private static final int IMAGE_HEIGHT = 406; + + @RunWith(AndroidJUnit4.class) + public static final class General extends GestureRecognizerTest { + + @Test + public void recognize_successWithValidModels() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void recognize_successWithEmptyResult() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(NO_HANDS_IMAGE)); + assertThat(actualResult.landmarks()).isEmpty(); + assertThat(actualResult.worldLandmarks()).isEmpty(); + assertThat(actualResult.handednesses()).isEmpty(); + assertThat(actualResult.gestures()).isEmpty(); + } + + @Test + public void recognize_successWithMinGestureConfidence() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + // TODO update the confidence to be in range [0,1] after embedding model + // and scoring calculator is integrated. + .setMinGestureConfidence(3.0f) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + // Only contains one top scoring gesture. + assertThat(actualResult.gestures().get(0)).hasSize(1); + assertActualGestureEqualExpectedGesture( + actualResult.gestures().get(0).get(0), expectedResult.gestures().get(0).get(0)); + } + + @Test + public void recognize_successWithNumHands() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setNumHands(2) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE)); + assertThat(actualResult.handednesses()).hasSize(2); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends GestureRecognizerTest { + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setBaseOptionsHandDetector( + BaseOptions.builder() + .setModelAssetPath(HAND_DETECTOR_MODEL_FILE) + .build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder() + .setModelAssetPath(HAND_LANDMARKER_MODEL_FILE) + .build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setRunningMode(mode) + .setResultListener((gestureRecognitionResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void recognize_failsWithCallingWrongApiInImageMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((gestureRecognitionResult, inputImage) -> {}) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void recognize_successWithImageMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void recognize_successWithVideoMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + for (int i = 0; i < 3; i++) { + GestureRecognitionResult actualResult = + gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + } + + @Test + public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception { + Image image = getImageFromAsset(THUMB_UP_IMAGE); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + gestureRecognizer.recognizeAsync(image, 1); + MediaPipeException exception = + assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void recognize_successWithLiveSteamMode() throws Exception { + Image image = getImageFromAsset(THUMB_UP_IMAGE); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + gestureRecognizer.recognizeAsync(image, i); + } + } + } + + private static Image getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static GestureRecognitionResult getExpectedGestureRecognitionResult( + String filePath, String gestureLabel, int gestureIndex) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + LandmarksDetectionResult landmarksDetectionResultProto = + LandmarksDetectionResult.parser().parseFrom(istr); + ClassificationProto.ClassificationList gesturesProto = + ClassificationProto.ClassificationList.newBuilder() + .addClassification( + ClassificationProto.Classification.newBuilder() + .setLabel(gestureLabel) + .setIndex(gestureIndex)) + .build(); + return GestureRecognitionResult.create( + Arrays.asList(landmarksDetectionResultProto.getLandmarks()), + Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()), + Arrays.asList(landmarksDetectionResultProto.getClassifications()), + Arrays.asList(gesturesProto), + /*timestampMs=*/ 0); + } + + private static void assertActualResultApproximatelyEqualsToExpectedResult( + GestureRecognitionResult actualResult, GestureRecognitionResult expectedResult) { + // Expects to have the same number of hands detected. + assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size()); + assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size()); + assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size()); + assertThat(actualResult.gestures()).hasSize(expectedResult.gestures().size()); + + // Actual landmarks match expected landmarks. + assertThat(actualResult.landmarks().get(0)) + .comparingElementsUsing( + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) + .compare(actual.x(), expected.x()) + && Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) + .compare(actual.y(), expected.y()); + }, + "landmarks approximately equal to")) + .containsExactlyElementsIn(expectedResult.landmarks().get(0)); + + // Actual handedness matches expected handedness. + Category actualTopHandedness = actualResult.handednesses().get(0).get(0); + Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0); + assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index()); + assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName()); + + // Actual gesture with top score matches expected gesture. + Category actualTopGesture = actualResult.gestures().get(0).get(0); + Category expectedTopGesture = expectedResult.gestures().get(0).get(0); + assertActualGestureEqualExpectedGesture(actualTopGesture, expectedTopGesture); + } + + private static void assertActualGestureEqualExpectedGesture( + Category actualGesture, Category expectedGesture) { + assertThat(actualGesture.index()).isEqualTo(actualGesture.index()); + assertThat(expectedGesture.categoryName()).isEqualTo(expectedGesture.categoryName()); + } + + private static void assertImageSizeIsExpected(Image inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +} From b632e645f5bc1b744cbb4b477da75810e4369b72 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 14 Oct 2022 00:13:28 -0700 Subject: [PATCH 17/18] Merge all BUILD files in the subdirectories of java/com/google/mediapipe/tasks/vision into one BUILD file. PiperOrigin-RevId: 481074268 --- .../android/objectdetector/src/main/BUILD | 4 +- .../com/google/mediapipe/tasks/vision/BUILD | 142 ++++++++++++++++++ .../google/mediapipe/tasks/vision/core/BUILD | 54 ------- .../tasks/vision/gesturerecognizer/BUILD | 49 ------ .../tasks/vision/imageclassifier/BUILD | 46 ------ .../tasks/vision/objectdetector/BUILD | 44 ------ 6 files changed, 144 insertions(+), 195 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD index 65b98d647..acbdbd6eb 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -34,8 +34,8 @@ android_binary( "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:objectdetector", "//third_party:androidx_appcompat", "//third_party:androidx_constraint_layout", "//third_party:opencv", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD new file mode 100644 index 000000000..dcf3b3542 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -0,0 +1,142 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +android_library( + name = "core", + srcs = glob(["core/*.java"]), + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + ":libmediapipe_tasks_vision_jni_lib", + "//mediapipe/framework/formats:rect_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "@maven//:com_google_guava_guava", + ], +) + +# The native library of all MediaPipe vision tasks. +cc_binary( + name = "libmediapipe_tasks_vision_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + ], +) + +cc_library( + name = "libmediapipe_tasks_vision_jni_lib", + srcs = [":libmediapipe_tasks_vision_jni.so"], + alwayslink = 1, +) + +android_library( + name = "objectdetector", + srcs = [ + "objectdetector/ObjectDetectionResult.java", + "objectdetector/ObjectDetector.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "objectdetector/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "imageclassifier", + srcs = [ + "imageclassifier/ImageClassificationResult.java", + "imageclassifier/ImageClassifier.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imageclassifier/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "gesturerecognizer", + srcs = [ + "gesturerecognizer/GestureRecognitionResult.java", + "gesturerecognizer/GestureRecognizer.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "gesturerecognizer/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD deleted file mode 100644 index 453ae9a90..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -licenses(["notice"]) - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -android_library( - name = "core", - srcs = glob(["*.java"]), - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - ":libmediapipe_tasks_vision_jni_lib", - "//mediapipe/framework/formats:rect_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "@maven//:com_google_guava_guava", - ], -) - -# The native library of all MediaPipe vision tasks. -cc_binary( - name = "libmediapipe_tasks_vision_jni.so", - linkshared = 1, - linkstatic = 1, - deps = [ - "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", - "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", - "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", - "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", - ], -) - -cc_library( - name = "libmediapipe_tasks_vision_jni_lib", - srcs = [":libmediapipe_tasks_vision_jni.so"], - alwayslink = 1, -) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD deleted file mode 100644 index 7782a747e..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -android_library( - name = "gesturerecognizer", - srcs = [ - "GestureRecognitionResult.java", - "GestureRecognizer.java", - ], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - manifest = ":AndroidManifest.xml", - deps = [ - "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/framework/formats:classification_java_proto_lite", - "//mediapipe/framework/formats:landmark_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", - "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD deleted file mode 100644 index cecd9f521..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -android_library( - name = "imageclassifier", - srcs = [ - "ImageClassificationResult.java", - "ImageClassifier.java", - ], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - manifest = ":AndroidManifest.xml", - deps = [ - "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", - "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", - "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD deleted file mode 100644 index 8ba2705eb..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -android_library( - name = "objectdetector", - srcs = [ - "ObjectDetectionResult.java", - "ObjectDetector.java", - ], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - manifest = ":AndroidManifest.xml", - deps = [ - "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/framework/formats:detection_java_proto_lite", - "//mediapipe/framework/formats:location_data_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) From 0ebe6ccf5959bd2a9bcdd4b3839482fccf3ddf26 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 14 Oct 2022 01:00:52 -0700 Subject: [PATCH 18/18] Add filegroups to mediapipe java src dirs. PiperOrigin-RevId: 481080348 --- mediapipe/java/com/google/mediapipe/framework/image/BUILD | 7 +++++++ .../com/google/mediapipe/tasks/components/containers/BUILD | 7 +++++++ .../com/google/mediapipe/tasks/components/processors/BUILD | 7 +++++++ 3 files changed, 21 insertions(+) diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index abf82a892..bb3be318d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -30,3 +30,10 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +# Expose the java source files for building mediapipe AAR. +filegroup( + name = "java_src", + srcs = glob(["*.java"]), + visibility = ["//mediapipe:__subpackages__"], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 610bec911..9dfa53031 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -63,3 +63,10 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +# Expose the java source files for building mediapipe tasks core AAR. +filegroup( + name = "java_src", + srcs = glob(["*.java"]), + visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index 88516d806..1f99f1612 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -28,3 +28,10 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +# Expose the java source files for building mediapipe tasks core AAR. +filegroup( + name = "java_src", + srcs = glob(["*.java"]), + visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"], +)