Revised Face Stylizer tests
This commit is contained in:
parent
1216c0f998
commit
d21460e17b
|
@ -15,9 +15,7 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
@ -38,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
||||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
|
|
||||||
_MODEL = 'face_stylizer.task'
|
_MODEL = 'face_stylizer_color_ink.task'
|
||||||
_LARGE_FACE_IMAGE = "portrait.jpg"
|
_LARGE_FACE_IMAGE = "portrait.jpg"
|
||||||
_MODEL_IMAGE_SIZE = 256
|
_MODEL_IMAGE_SIZE = 256
|
||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
@ -181,163 +179,6 @@ class FaceStylizerTest(parameterized.TestCase):
|
||||||
stylized_image = stylizer.stylize(test_image, image_processing_options)
|
stylized_image = stylizer.stylize(test_image, image_processing_options)
|
||||||
self.assertIsNone(stylized_image)
|
self.assertIsNone(stylized_image)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
|
||||||
)
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'result callback must be provided'
|
|
||||||
):
|
|
||||||
with _FaceStylizer.create_from_options(options) as unused_stylizer:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
|
||||||
def test_illegal_result_callback(self, running_mode):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=running_mode,
|
|
||||||
result_callback=mock.MagicMock(),
|
|
||||||
)
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'result callback should not be provided'
|
|
||||||
):
|
|
||||||
with _FaceStylizer.create_from_options(options) as unused_stylizer:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_calling_stylize_for_video_in_image_mode(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.IMAGE,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'not initialized with the video mode'
|
|
||||||
):
|
|
||||||
stylizer.stylize_for_video(self.test_image, 0)
|
|
||||||
|
|
||||||
def test_calling_stylize_async_in_image_mode(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.IMAGE,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'not initialized with the live stream mode'
|
|
||||||
):
|
|
||||||
stylizer.stylize_async(self.test_image, 0)
|
|
||||||
|
|
||||||
def test_calling_classify_in_video_mode(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.VIDEO,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'not initialized with the image mode'
|
|
||||||
):
|
|
||||||
stylizer.stylize(self.test_image)
|
|
||||||
|
|
||||||
def test_calling_classify_async_in_video_mode(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.VIDEO,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'not initialized with the live stream mode'
|
|
||||||
):
|
|
||||||
stylizer.stylize_async(self.test_image, 0)
|
|
||||||
|
|
||||||
def test_classify_for_video_with_out_of_order_timestamp(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.VIDEO,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
unused_result = stylizer.stylize_for_video(self.test_image, 1)
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'Input timestamp must be monotonically increasing'
|
|
||||||
):
|
|
||||||
stylizer.stylize_for_video(self.test_image, 0)
|
|
||||||
|
|
||||||
def test_stylize_for_video(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.VIDEO,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
for timestamp in range(0, 300, 30):
|
|
||||||
stylized_image = stylizer.stylize_for_video(
|
|
||||||
self.test_image, timestamp
|
|
||||||
)
|
|
||||||
self.assertIsInstance(stylized_image, _Image)
|
|
||||||
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
|
|
||||||
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
|
|
||||||
|
|
||||||
def test_calling_stylize_in_live_stream_mode(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
|
||||||
result_callback=mock.MagicMock(),
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'not initialized with the image mode'
|
|
||||||
):
|
|
||||||
stylizer.stylize(self.test_image)
|
|
||||||
|
|
||||||
def test_calling_stylize_for_video_in_live_stream_mode(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
|
||||||
result_callback=mock.MagicMock(),
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'not initialized with the video mode'
|
|
||||||
):
|
|
||||||
stylizer.stylize_for_video(self.test_image, 0)
|
|
||||||
|
|
||||||
def test_stylize_async_calls_with_illegal_timestamp(self):
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
|
||||||
result_callback=mock.MagicMock(),
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
stylizer.stylize_async(self.test_image, 100)
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, r'Input timestamp must be monotonically increasing'
|
|
||||||
):
|
|
||||||
stylizer.stylize_async(self.test_image, 0)
|
|
||||||
|
|
||||||
def test_stylize_async_calls(self):
|
|
||||||
observed_timestamp_ms = -1
|
|
||||||
|
|
||||||
def check_result(
|
|
||||||
stylized_image: _Image, output_image: _Image, timestamp_ms: int
|
|
||||||
):
|
|
||||||
self.assertIsInstance(stylized_image, _Image)
|
|
||||||
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
|
|
||||||
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
|
|
||||||
self.assertTrue(
|
|
||||||
np.array_equal(
|
|
||||||
output_image.numpy_view(), self.test_image.numpy_view()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
|
||||||
self.observed_timestamp_ms = timestamp_ms
|
|
||||||
|
|
||||||
options = _FaceStylizerOptions(
|
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
|
||||||
result_callback=check_result,
|
|
||||||
)
|
|
||||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
|
||||||
for timestamp in range(0, 300, 30):
|
|
||||||
stylizer.stylize_async(self.test_image, timestamp)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
|
@ -14,7 +14,7 @@
|
||||||
"""MediaPipe face stylizer task."""
|
"""MediaPipe face stylizer task."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Optional
|
from typing import Optional, Callable
|
||||||
|
|
||||||
from mediapipe.python import packet_creator
|
from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
|
@ -31,6 +31,7 @@ _BaseOptions = base_options_module.BaseOptions
|
||||||
_FaceStylizerGraphOptionsProto = (
|
_FaceStylizerGraphOptionsProto = (
|
||||||
face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
|
face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
|
||||||
)
|
)
|
||||||
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
|
@ -54,11 +55,15 @@ class FaceStylizerOptions:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base_options: _BaseOptions
|
base_options: _BaseOptions
|
||||||
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
|
def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
|
||||||
"""Generates an FaceStylizerOptions protobuf object."""
|
"""Generates an FaceStylizerOptions protobuf object."""
|
||||||
base_options_proto = self.base_options.to_pb2()
|
base_options_proto = self.base_options.to_pb2()
|
||||||
|
base_options_proto.use_stream_mode = (
|
||||||
|
False if self.running_mode == _RunningMode.IMAGE else True
|
||||||
|
)
|
||||||
return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)
|
return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)
|
||||||
|
|
||||||
|
|
||||||
|
|
1
mediapipe/tasks/testdata/vision/BUILD
vendored
1
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -48,7 +48,6 @@ mediapipe_files(srcs = [
|
||||||
"face_landmark.tflite",
|
"face_landmark.tflite",
|
||||||
"face_landmarker.task",
|
"face_landmarker.task",
|
||||||
"face_landmarker_v2.task",
|
"face_landmarker_v2.task",
|
||||||
"face_stylizer_color_ink.task",
|
|
||||||
"fist.jpg",
|
"fist.jpg",
|
||||||
"fist.png",
|
"fist.png",
|
||||||
"gesture_recognizer.task",
|
"gesture_recognizer.task",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user