Internal change
PiperOrigin-RevId: 528517562
This commit is contained in:
parent
cab619f8da
commit
085f8265fb
|
@ -78,6 +78,15 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "face_stylizer",
|
||||
srcs = ["face_stylizer.py"],
|
||||
deps = [
|
||||
":metadata_writer",
|
||||
":model_asset_bundle_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_asset_bundle_utils",
|
||||
srcs = ["model_asset_bundle_utils.py"],
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and creates model asset bundle for face stylizer."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
|
||||
|
||||
_MODEL_NAME = "FaceStylizer"
|
||||
_MODEL_DESCRIPTION = "Performs face stylization on images."
|
||||
_FACE_DETECTOR_MODEL = "face_detector.tflite"
|
||||
_FACE_LANDMARKS_DETECTOR_MODEL = "face_landmarks_detector.tflite"
|
||||
_FACE_STYLIZER_MODEL = "face_stylizer.tflite"
|
||||
_FACE_STYLIZER_TASK = "face_stylizer.task"
|
||||
|
||||
|
||||
class MetadataWriter:
|
||||
"""MetadataWriter to write the metadata for face stylizer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
face_detector_model_buffer: bytearray,
|
||||
face_landmarks_detector_model_buffer: bytearray,
|
||||
face_stylizer_metadata_writer: metadata_writer.MetadataWriter,
|
||||
) -> None:
|
||||
"""Initializes MetadataWriter to write the metadata and model asset bundle.
|
||||
|
||||
Args:
|
||||
face_detector_model_buffer: A valid flatbuffer loaded from the face
|
||||
detector TFLite model file with metadata already packed inside.
|
||||
face_landmarks_detector_model_buffer: A valid flatbuffer loaded from the
|
||||
face landmarks detector TFLite model file with metadata already packed
|
||||
inside.
|
||||
face_stylizer_metadata_writer: Metadata writer to write face stylizer
|
||||
metadata into the TFLite file.
|
||||
"""
|
||||
self._face_detector_model_buffer = face_detector_model_buffer
|
||||
self._face_landmarks_detector_model_buffer = (
|
||||
face_landmarks_detector_model_buffer
|
||||
)
|
||||
self._face_stylizer_metadata_writer = face_stylizer_metadata_writer
|
||||
self._temp_folder = tempfile.TemporaryDirectory()
|
||||
|
||||
def __del__(self):
|
||||
if os.path.exists(self._temp_folder.name):
|
||||
self._temp_folder.cleanup()
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
face_stylizer_model_buffer: bytearray,
|
||||
face_detector_model_buffer: bytearray,
|
||||
face_landmarks_detector_model_buffer: bytearray,
|
||||
input_norm_mean: List[float],
|
||||
input_norm_std: List[float],
|
||||
) -> "MetadataWriter":
|
||||
"""Creates MetadataWriter to write the metadata for face stylizer.
|
||||
|
||||
The parameters required in this method are mandatory when using MediaPipe
|
||||
Tasks.
|
||||
|
||||
Note that only the output TFLite is used for deployment. The output JSON
|
||||
content is used to interpret the metadata content.
|
||||
|
||||
Args:
|
||||
face_stylizer_model_buffer: A valid flatbuffer loaded from the face
|
||||
stylizer TFLite model file.
|
||||
face_detector_model_buffer: A valid flatbuffer loaded from the face
|
||||
detector TFLite model file with metadata already packed inside.
|
||||
face_landmarks_detector_model_buffer: A valid flatbuffer loaded from the
|
||||
face landmarks detector TFLite model file with metadata already packed
|
||||
inside.
|
||||
input_norm_mean: the mean value used in the input tensor normalization for
|
||||
face stylizer model [1].
|
||||
input_norm_std: the std value used in the input tensor normalizarion for
|
||||
face stylizer model [1].
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
|
||||
|
||||
Returns:
|
||||
A MetadataWriter object.
|
||||
"""
|
||||
face_stylizer_writer = metadata_writer.MetadataWriter(
|
||||
face_stylizer_model_buffer
|
||||
)
|
||||
face_stylizer_writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
face_stylizer_writer.add_image_input(input_norm_mean, input_norm_std)
|
||||
return cls(
|
||||
face_detector_model_buffer,
|
||||
face_landmarks_detector_model_buffer,
|
||||
face_stylizer_writer,
|
||||
)
|
||||
|
||||
def populate(self):
|
||||
"""Populates the metadata and creates model asset bundle.
|
||||
|
||||
Note that only the output model asset bundle is used for deployment.
|
||||
The output JSON content is used to interpret the face stylizer metadata
|
||||
content.
|
||||
|
||||
Returns:
|
||||
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
||||
"""
|
||||
# Write metadata into the face stylizer TFLite model.
|
||||
face_stylizer_model_buffer, face_stylizer_metadata_json = (
|
||||
self._face_stylizer_metadata_writer.populate()
|
||||
)
|
||||
# Create the model asset bundle for the face stylizer task.
|
||||
face_stylizer_models = {
|
||||
_FACE_DETECTOR_MODEL: self._face_detector_model_buffer,
|
||||
_FACE_LANDMARKS_DETECTOR_MODEL: (
|
||||
self._face_landmarks_detector_model_buffer
|
||||
),
|
||||
_FACE_STYLIZER_MODEL: face_stylizer_model_buffer,
|
||||
}
|
||||
output_path = os.path.join(self._temp_folder.name, _FACE_STYLIZER_TASK)
|
||||
model_asset_bundle_utils.create_model_asset_bundle(
|
||||
face_stylizer_models, output_path
|
||||
)
|
||||
with open(output_path, "rb") as f:
|
||||
face_stylizer_model_bundle_buffer = f.read()
|
||||
return face_stylizer_model_bundle_buffer, face_stylizer_metadata_json
|
|
@ -107,3 +107,16 @@ py_test(
|
|||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "face_stylizer_test",
|
||||
srcs = ["face_stylizer_test.py"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata_writer.face_stylizer."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||
_NORM_MEAN = 0
|
||||
_NORM_STD = 255
|
||||
_TFLITE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "dummy_face_stylizer.tflite")
|
||||
)
|
||||
_EXPECTED_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "face_stylizer.json")
|
||||
)
|
||||
|
||||
|
||||
class FaceStylizerTest(parameterized.TestCase):
|
||||
|
||||
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
||||
# Use dummy model buffer for unit test only.
|
||||
with open(_TFLITE, "rb") as f:
|
||||
face_stylizer_model_buffer = f.read()
|
||||
face_detector_model_buffer = b"\x33\x44"
|
||||
face_landmarks_detector_model_buffer = b"\x55\x66"
|
||||
writer = face_stylizer.MetadataWriter.create(
|
||||
face_stylizer_model_buffer,
|
||||
face_detector_model_buffer,
|
||||
face_landmarks_detector_model_buffer,
|
||||
input_norm_mean=[_NORM_MEAN],
|
||||
input_norm_std=[_NORM_STD],
|
||||
)
|
||||
model_bundle_content, metadata_json = writer.populate()
|
||||
with open(_EXPECTED_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
temp_folder = tempfile.TemporaryDirectory()
|
||||
|
||||
# Checks the model bundle can be extracted successfully.
|
||||
model_bundle_filepath = os.path.join(temp_folder.name, "face_stylizer.task")
|
||||
|
||||
with open(model_bundle_filepath, "wb") as f:
|
||||
f.write(model_bundle_content)
|
||||
|
||||
with zipfile.ZipFile(model_bundle_filepath) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set([
|
||||
"face_detector.tflite",
|
||||
"face_landmarks_detector.tflite",
|
||||
"face_stylizer.tflite",
|
||||
]),
|
||||
)
|
||||
zf.extractall(temp_folder.name)
|
||||
temp_folder.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
4
mediapipe/tasks/testdata/metadata/BUILD
vendored
4
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -32,8 +32,10 @@ mediapipe_files(srcs = [
|
|||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"dummy_face_stylizer.tflite",
|
||||
"efficientdet_lite0_fp16_no_nms.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"face_stylizer.json",
|
||||
"labelmap.txt",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||
|
@ -96,6 +98,7 @@ filegroup(
|
|||
"bert_text_classifier_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"dummy_face_stylizer.tflite",
|
||||
"efficientdet_lite0_fp16_no_nms.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
|
@ -132,6 +135,7 @@ filegroup(
|
|||
"efficientdet_lite0_fp16_no_nms_anchors.csv",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"external_file",
|
||||
"face_stylizer.json",
|
||||
"feature_tensor_meta.json",
|
||||
"general_meta.json",
|
||||
"golden_json.json",
|
||||
|
|
47
mediapipe/tasks/testdata/metadata/face_stylizer.json
vendored
Normal file
47
mediapipe/tasks/testdata/metadata/face_stylizer.json
vendored
Normal file
|
@ -0,0 +1,47 @@
|
|||
{
|
||||
"name": "FaceStylizer",
|
||||
"description": "Performs face stylization on images.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
0.0
|
||||
],
|
||||
"std": [
|
||||
255.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "PartitionedCall:0"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
12
third_party/external_files.bzl
vendored
12
third_party/external_files.bzl
vendored
|
@ -250,6 +250,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_dummy_face_stylizer_tflite",
|
||||
sha256 = "c44a32a673790aac4aca63ca4b4192b9870c21045241e69d9fe09b7ad1a38d65",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_face_stylizer.tflite?generation=1682960595073526"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_dummy_gesture_recognizer_task",
|
||||
sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
|
||||
|
@ -418,6 +424,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylization_dummy.tflite?generation=1678323589048063"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_face_stylizer_json",
|
||||
sha256 = "ad89860d5daba6a1c4163a576428713fc3ddab76d6bbaf06d675164423ae159f",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylizer.json?generation=1682960598942694"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_face_stylizer_task",
|
||||
sha256 = "b34f3896cbe860468538cf5a562c0468964f182b8bb07cb527224312969d1625",
|
||||
|
|
Loading…
Reference in New Issue
Block a user