Internal change

PiperOrigin-RevId: 528517562
This commit is contained in:
Yuqi Li 2023-05-01 10:59:22 -07:00 committed by Copybara-Service
parent cab619f8da
commit 085f8265fb
7 changed files with 303 additions and 0 deletions

View File

@ -78,6 +78,15 @@ py_library(
], ],
) )
py_library(
name = "face_stylizer",
srcs = ["face_stylizer.py"],
deps = [
":metadata_writer",
":model_asset_bundle_utils",
],
)
py_library( py_library(
name = "model_asset_bundle_utils", name = "model_asset_bundle_utils",
srcs = ["model_asset_bundle_utils.py"], srcs = ["model_asset_bundle_utils.py"],

View File

@ -0,0 +1,138 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and creates model asset bundle for face stylizer."""
import os
import tempfile
from typing import List
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
_MODEL_NAME = "FaceStylizer"
_MODEL_DESCRIPTION = "Performs face stylization on images."
_FACE_DETECTOR_MODEL = "face_detector.tflite"
_FACE_LANDMARKS_DETECTOR_MODEL = "face_landmarks_detector.tflite"
_FACE_STYLIZER_MODEL = "face_stylizer.tflite"
_FACE_STYLIZER_TASK = "face_stylizer.task"
class MetadataWriter:
"""MetadataWriter to write the metadata for face stylizer."""
def __init__(
self,
face_detector_model_buffer: bytearray,
face_landmarks_detector_model_buffer: bytearray,
face_stylizer_metadata_writer: metadata_writer.MetadataWriter,
) -> None:
"""Initializes MetadataWriter to write the metadata and model asset bundle.
Args:
face_detector_model_buffer: A valid flatbuffer loaded from the face
detector TFLite model file with metadata already packed inside.
face_landmarks_detector_model_buffer: A valid flatbuffer loaded from the
face landmarks detector TFLite model file with metadata already packed
inside.
face_stylizer_metadata_writer: Metadata writer to write face stylizer
metadata into the TFLite file.
"""
self._face_detector_model_buffer = face_detector_model_buffer
self._face_landmarks_detector_model_buffer = (
face_landmarks_detector_model_buffer
)
self._face_stylizer_metadata_writer = face_stylizer_metadata_writer
self._temp_folder = tempfile.TemporaryDirectory()
def __del__(self):
if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup()
@classmethod
def create(
cls,
face_stylizer_model_buffer: bytearray,
face_detector_model_buffer: bytearray,
face_landmarks_detector_model_buffer: bytearray,
input_norm_mean: List[float],
input_norm_std: List[float],
) -> "MetadataWriter":
"""Creates MetadataWriter to write the metadata for face stylizer.
The parameters required in this method are mandatory when using MediaPipe
Tasks.
Note that only the output TFLite is used for deployment. The output JSON
content is used to interpret the metadata content.
Args:
face_stylizer_model_buffer: A valid flatbuffer loaded from the face
stylizer TFLite model file.
face_detector_model_buffer: A valid flatbuffer loaded from the face
detector TFLite model file with metadata already packed inside.
face_landmarks_detector_model_buffer: A valid flatbuffer loaded from the
face landmarks detector TFLite model file with metadata already packed
inside.
input_norm_mean: the mean value used in the input tensor normalization for
face stylizer model [1].
input_norm_std: the std value used in the input tensor normalizarion for
face stylizer model [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
Returns:
A MetadataWriter object.
"""
face_stylizer_writer = metadata_writer.MetadataWriter(
face_stylizer_model_buffer
)
face_stylizer_writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
face_stylizer_writer.add_image_input(input_norm_mean, input_norm_std)
return cls(
face_detector_model_buffer,
face_landmarks_detector_model_buffer,
face_stylizer_writer,
)
def populate(self):
"""Populates the metadata and creates model asset bundle.
Note that only the output model asset bundle is used for deployment.
The output JSON content is used to interpret the face stylizer metadata
content.
Returns:
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
"""
# Write metadata into the face stylizer TFLite model.
face_stylizer_model_buffer, face_stylizer_metadata_json = (
self._face_stylizer_metadata_writer.populate()
)
# Create the model asset bundle for the face stylizer task.
face_stylizer_models = {
_FACE_DETECTOR_MODEL: self._face_detector_model_buffer,
_FACE_LANDMARKS_DETECTOR_MODEL: (
self._face_landmarks_detector_model_buffer
),
_FACE_STYLIZER_MODEL: face_stylizer_model_buffer,
}
output_path = os.path.join(self._temp_folder.name, _FACE_STYLIZER_TASK)
model_asset_bundle_utils.create_model_asset_bundle(
face_stylizer_models, output_path
)
with open(output_path, "rb") as f:
face_stylizer_model_bundle_buffer = f.read()
return face_stylizer_model_bundle_buffer, face_stylizer_metadata_json

View File

@ -107,3 +107,16 @@ py_test(
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
py_test(
name = "face_stylizer_test",
srcs = ["face_stylizer_test.py"],
data = [
"//mediapipe/tasks/testdata/metadata:data_files",
"//mediapipe/tasks/testdata/metadata:model_files",
],
deps = [
"//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,80 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer.face_stylizer."""
import os
import tempfile
import zipfile
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
_NORM_MEAN = 0
_NORM_STD = 255
_TFLITE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "dummy_face_stylizer.tflite")
)
_EXPECTED_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "face_stylizer.json")
)
class FaceStylizerTest(parameterized.TestCase):
def test_write_metadata_and_create_model_asset_bundle_successful(self):
# Use dummy model buffer for unit test only.
with open(_TFLITE, "rb") as f:
face_stylizer_model_buffer = f.read()
face_detector_model_buffer = b"\x33\x44"
face_landmarks_detector_model_buffer = b"\x55\x66"
writer = face_stylizer.MetadataWriter.create(
face_stylizer_model_buffer,
face_detector_model_buffer,
face_landmarks_detector_model_buffer,
input_norm_mean=[_NORM_MEAN],
input_norm_std=[_NORM_STD],
)
model_bundle_content, metadata_json = writer.populate()
with open(_EXPECTED_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
temp_folder = tempfile.TemporaryDirectory()
# Checks the model bundle can be extracted successfully.
model_bundle_filepath = os.path.join(temp_folder.name, "face_stylizer.task")
with open(model_bundle_filepath, "wb") as f:
f.write(model_bundle_content)
with zipfile.ZipFile(model_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set([
"face_detector.tflite",
"face_landmarks_detector.tflite",
"face_stylizer.tflite",
]),
)
zf.extractall(temp_folder.name)
temp_folder.cleanup()
if __name__ == "__main__":
absltest.main()

View File

@ -32,8 +32,10 @@ mediapipe_files(srcs = [
"deeplabv3_with_activation.json", "deeplabv3_with_activation.json",
"deeplabv3_without_labels.json", "deeplabv3_without_labels.json",
"deeplabv3_without_metadata.tflite", "deeplabv3_without_metadata.tflite",
"dummy_face_stylizer.tflite",
"efficientdet_lite0_fp16_no_nms.tflite", "efficientdet_lite0_fp16_no_nms.tflite",
"efficientdet_lite0_v1.tflite", "efficientdet_lite0_v1.tflite",
"face_stylizer.json",
"labelmap.txt", "labelmap.txt",
"mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-custom-metadata.tflite",
"mobile_ica_8bit-with-large-min-parser-version.tflite", "mobile_ica_8bit-with-large-min-parser-version.tflite",
@ -96,6 +98,7 @@ filegroup(
"bert_text_classifier_no_metadata.tflite", "bert_text_classifier_no_metadata.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
"deeplabv3_without_metadata.tflite", "deeplabv3_without_metadata.tflite",
"dummy_face_stylizer.tflite",
"efficientdet_lite0_fp16_no_nms.tflite", "efficientdet_lite0_fp16_no_nms.tflite",
"efficientdet_lite0_v1.tflite", "efficientdet_lite0_v1.tflite",
"mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-custom-metadata.tflite",
@ -132,6 +135,7 @@ filegroup(
"efficientdet_lite0_fp16_no_nms_anchors.csv", "efficientdet_lite0_fp16_no_nms_anchors.csv",
"efficientdet_lite0_v1.json", "efficientdet_lite0_v1.json",
"external_file", "external_file",
"face_stylizer.json",
"feature_tensor_meta.json", "feature_tensor_meta.json",
"general_meta.json", "general_meta.json",
"golden_json.json", "golden_json.json",

View File

@ -0,0 +1,47 @@
{
"name": "FaceStylizer",
"description": "Performs face stylization on images.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0
],
"std": [
255.0
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "PartitionedCall:0"
}
]
}
],
"min_parser_version": "1.0.0"
}

View File

@ -250,6 +250,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"], urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"],
) )
http_file(
name = "com_google_mediapipe_dummy_face_stylizer_tflite",
sha256 = "c44a32a673790aac4aca63ca4b4192b9870c21045241e69d9fe09b7ad1a38d65",
urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_face_stylizer.tflite?generation=1682960595073526"],
)
http_file( http_file(
name = "com_google_mediapipe_dummy_gesture_recognizer_task", name = "com_google_mediapipe_dummy_gesture_recognizer_task",
sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e", sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
@ -418,6 +424,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylization_dummy.tflite?generation=1678323589048063"], urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylization_dummy.tflite?generation=1678323589048063"],
) )
http_file(
name = "com_google_mediapipe_face_stylizer_json",
sha256 = "ad89860d5daba6a1c4163a576428713fc3ddab76d6bbaf06d675164423ae159f",
urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylizer.json?generation=1682960598942694"],
)
http_file( http_file(
name = "com_google_mediapipe_face_stylizer_task", name = "com_google_mediapipe_face_stylizer_task",
sha256 = "b34f3896cbe860468538cf5a562c0468964f182b8bb07cb527224312969d1625", sha256 = "b34f3896cbe860468538cf5a562c0468964f182b8bb07cb527224312969d1625",