diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD index f7db16682..268dec494 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -78,6 +78,15 @@ py_library( ], ) +py_library( + name = "face_stylizer", + srcs = ["face_stylizer.py"], + deps = [ + ":metadata_writer", + ":model_asset_bundle_utils", + ], +) + py_library( name = "model_asset_bundle_utils", srcs = ["model_asset_bundle_utils.py"], diff --git a/mediapipe/tasks/python/metadata/metadata_writers/face_stylizer.py b/mediapipe/tasks/python/metadata/metadata_writers/face_stylizer.py new file mode 100644 index 000000000..01d0d7027 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/face_stylizer.py @@ -0,0 +1,138 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and creates model asset bundle for face stylizer.""" + +import os +import tempfile +from typing import List + +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer +from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils + +_MODEL_NAME = "FaceStylizer" +_MODEL_DESCRIPTION = "Performs face stylization on images." 
+_FACE_DETECTOR_MODEL = "face_detector.tflite" +_FACE_LANDMARKS_DETECTOR_MODEL = "face_landmarks_detector.tflite" +_FACE_STYLIZER_MODEL = "face_stylizer.tflite" +_FACE_STYLIZER_TASK = "face_stylizer.task" + + +class MetadataWriter: + """Writes face stylizer metadata and creates the model asset bundle.""" + + def __init__( + self, + face_detector_model_buffer: bytearray, + face_landmarks_detector_model_buffer: bytearray, + face_stylizer_metadata_writer: metadata_writer.MetadataWriter, + ) -> None: + """Initializes MetadataWriter to write the metadata and model asset bundle. + + Args: + face_detector_model_buffer: A valid flatbuffer loaded from the face + detector TFLite model file with metadata already packed inside. + face_landmarks_detector_model_buffer: A valid flatbuffer loaded from the + face landmarks detector TFLite model file with metadata already packed + inside. + face_stylizer_metadata_writer: Metadata writer to write face stylizer + metadata into the TFLite file. + """ + self._face_detector_model_buffer = face_detector_model_buffer + self._face_landmarks_detector_model_buffer = ( + face_landmarks_detector_model_buffer + ) + self._face_stylizer_metadata_writer = face_stylizer_metadata_writer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self) -> None: + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + @classmethod + def create( + cls, + face_stylizer_model_buffer: bytearray, + face_detector_model_buffer: bytearray, + face_landmarks_detector_model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + ) -> "MetadataWriter": + """Creates MetadataWriter to write the metadata for face stylizer. + + The parameters required in this method are mandatory when using MediaPipe + Tasks. + + Note that only the output model asset bundle is used for deployment. The + output JSON content is used to interpret the metadata content. 
+ + Args: + face_stylizer_model_buffer: A valid flatbuffer loaded from the face + stylizer TFLite model file. + face_detector_model_buffer: A valid flatbuffer loaded from the face + detector TFLite model file with metadata already packed inside. + face_landmarks_detector_model_buffer: A valid flatbuffer loaded from the + face landmarks detector TFLite model file with metadata already packed + inside. + input_norm_mean: the mean value used in the input tensor normalization for + face stylizer model [1]. + input_norm_std: the std value used in the input tensor normalization for + face stylizer model [1]. + + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 + + Returns: + A MetadataWriter object. + """ + face_stylizer_writer = metadata_writer.MetadataWriter( + face_stylizer_model_buffer + ) + face_stylizer_writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION) + face_stylizer_writer.add_image_input(input_norm_mean, input_norm_std) + return cls( + face_detector_model_buffer, + face_landmarks_detector_model_buffer, + face_stylizer_writer, + ) + + def populate(self): + """Populates the metadata and creates model asset bundle. + + Note that only the output model asset bundle is used for deployment. + The output JSON content is used to interpret the face stylizer metadata + content. + + Returns: + A tuple of (model_asset_bundle_in_bytes, metadata_json_content) + """ + # Write metadata into the face stylizer TFLite model. + face_stylizer_model_buffer, face_stylizer_metadata_json = ( + self._face_stylizer_metadata_writer.populate() + ) + # Create the model asset bundle for the face stylizer task. 
+ face_stylizer_models = { + _FACE_DETECTOR_MODEL: self._face_detector_model_buffer, + _FACE_LANDMARKS_DETECTOR_MODEL: ( + self._face_landmarks_detector_model_buffer + ), + _FACE_STYLIZER_MODEL: face_stylizer_model_buffer, + } + output_path = os.path.join(self._temp_folder.name, _FACE_STYLIZER_TASK) + model_asset_bundle_utils.create_model_asset_bundle( + face_stylizer_models, output_path + ) + with open(output_path, "rb") as f: + face_stylizer_model_bundle_buffer = f.read() + return face_stylizer_model_bundle_buffer, face_stylizer_metadata_json diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD index 66ddf54b8..863cc1a64 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD @@ -107,3 +107,16 @@ py_test( "//mediapipe/tasks/python/test:test_utils", ], ) + +py_test( + name = "face_stylizer_test", + srcs = ["face_stylizer_test.py"], + data = [ + "//mediapipe/tasks/testdata/metadata:data_files", + "//mediapipe/tasks/testdata/metadata:model_files", + ], + deps = [ + "//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer", + "//mediapipe/tasks/python/test:test_utils", + ], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/face_stylizer_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/face_stylizer_test.py new file mode 100644 index 000000000..127a4a47d --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/face_stylizer_test.py @@ -0,0 +1,80 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for metadata_writers.face_stylizer.""" + +import os +import tempfile +import zipfile + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer +from mediapipe.tasks.python.test import test_utils + +_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata" +_NORM_MEAN = 0 +_NORM_STD = 255 +_TFLITE = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "dummy_face_stylizer.tflite") +) +_EXPECTED_JSON = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "face_stylizer.json") +) + + +class FaceStylizerTest(parameterized.TestCase): + + def test_write_metadata_and_create_model_asset_bundle_successful(self): + # Use dummy model buffers for unit test only. + with open(_TFLITE, "rb") as f: + face_stylizer_model_buffer = f.read() + face_detector_model_buffer = b"\x33\x44" + face_landmarks_detector_model_buffer = b"\x55\x66" + writer = face_stylizer.MetadataWriter.create( + face_stylizer_model_buffer, + face_detector_model_buffer, + face_landmarks_detector_model_buffer, + input_norm_mean=[_NORM_MEAN], + input_norm_std=[_NORM_STD], + ) + model_bundle_content, metadata_json = writer.populate() + with open(_EXPECTED_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + temp_folder = tempfile.TemporaryDirectory() + + # Checks the model bundle can be extracted successfully. 
+ model_bundle_filepath = os.path.join(temp_folder.name, "face_stylizer.task") + + with open(model_bundle_filepath, "wb") as f: + f.write(model_bundle_content) + + with zipfile.ZipFile(model_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set([ + "face_detector.tflite", + "face_landmarks_detector.tflite", + "face_stylizer.tflite", + ]), + ) + zf.extractall(temp_folder.name) + temp_folder.cleanup() + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 710e9d8cf..25cac9e15 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -32,8 +32,10 @@ mediapipe_files(srcs = [ "deeplabv3_with_activation.json", "deeplabv3_without_labels.json", "deeplabv3_without_metadata.tflite", + "dummy_face_stylizer.tflite", "efficientdet_lite0_fp16_no_nms.tflite", "efficientdet_lite0_v1.tflite", + "face_stylizer.json", "labelmap.txt", "mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-large-min-parser-version.tflite", @@ -96,6 +98,7 @@ filegroup( "bert_text_classifier_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", "deeplabv3_without_metadata.tflite", + "dummy_face_stylizer.tflite", "efficientdet_lite0_fp16_no_nms.tflite", "efficientdet_lite0_v1.tflite", "mobile_ica_8bit-with-custom-metadata.tflite", @@ -132,6 +135,7 @@ filegroup( "efficientdet_lite0_fp16_no_nms_anchors.csv", "efficientdet_lite0_v1.json", "external_file", + "face_stylizer.json", "feature_tensor_meta.json", "general_meta.json", "golden_json.json", diff --git a/mediapipe/tasks/testdata/metadata/face_stylizer.json b/mediapipe/tasks/testdata/metadata/face_stylizer.json new file mode 100644 index 000000000..e987ad816 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/face_stylizer.json @@ -0,0 +1,47 @@ +{ + "name": "FaceStylizer", + "description": "Performs face stylization on images.", + "subgraph_metadata": 
[ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0 + ], + "std": [ + 255.0 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "PartitionedCall:0" + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index ff338bae2..af9361bb3 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -250,6 +250,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"], ) + http_file( + name = "com_google_mediapipe_dummy_face_stylizer_tflite", + sha256 = "c44a32a673790aac4aca63ca4b4192b9870c21045241e69d9fe09b7ad1a38d65", + urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_face_stylizer.tflite?generation=1682960595073526"], + ) + http_file( name = "com_google_mediapipe_dummy_gesture_recognizer_task", sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e", @@ -418,6 +424,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylization_dummy.tflite?generation=1678323589048063"], ) + http_file( + name = "com_google_mediapipe_face_stylizer_json", + sha256 = "ad89860d5daba6a1c4163a576428713fc3ddab76d6bbaf06d675164423ae159f", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylizer.json?generation=1682960598942694"], + ) + http_file( name = "com_google_mediapipe_face_stylizer_task", sha256 = "b34f3896cbe860468538cf5a562c0468964f182b8bb07cb527224312969d1625",