Revised Face Stylizer tests
This commit is contained in:
parent
1216c0f998
commit
d21460e17b
|
@ -15,9 +15,7 @@
|
|||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
@ -38,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
|||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL = 'face_stylizer.task'
|
||||
_MODEL = 'face_stylizer_color_ink.task'
|
||||
_LARGE_FACE_IMAGE = "portrait.jpg"
|
||||
_MODEL_IMAGE_SIZE = 256
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
@ -181,163 +179,6 @@ class FaceStylizerTest(parameterized.TestCase):
|
|||
stylized_image = stylizer.stylize(test_image, image_processing_options)
|
||||
self.assertIsNone(stylized_image)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _FaceStylizer.create_from_options(options) as unused_stylizer:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _FaceStylizer.create_from_options(options) as unused_stylizer:
|
||||
pass
|
||||
|
||||
def test_calling_stylize_for_video_in_image_mode(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
stylizer.stylize_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_stylize_async_in_image_mode(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
stylizer.stylize_async(self.test_image, 0)
|
||||
|
||||
def test_calling_classify_in_video_mode(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
stylizer.stylize(self.test_image)
|
||||
|
||||
def test_calling_classify_async_in_video_mode(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
stylizer.stylize_async(self.test_image, 0)
|
||||
|
||||
def test_classify_for_video_with_out_of_order_timestamp(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
unused_result = stylizer.stylize_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
stylizer.stylize_for_video(self.test_image, 0)
|
||||
|
||||
def test_stylize_for_video(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
for timestamp in range(0, 300, 30):
|
||||
stylized_image = stylizer.stylize_for_video(
|
||||
self.test_image, timestamp
|
||||
)
|
||||
self.assertIsInstance(stylized_image, _Image)
|
||||
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
|
||||
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
|
||||
|
||||
def test_calling_stylize_in_live_stream_mode(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
stylizer.stylize(self.test_image)
|
||||
|
||||
def test_calling_stylize_for_video_in_live_stream_mode(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
stylizer.stylize_for_video(self.test_image, 0)
|
||||
|
||||
def test_stylize_async_calls_with_illegal_timestamp(self):
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
stylizer.stylize_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
stylizer.stylize_async(self.test_image, 0)
|
||||
|
||||
def test_stylize_async_calls(self):
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
stylized_image: _Image, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
self.assertIsInstance(stylized_image, _Image)
|
||||
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
|
||||
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
output_image.numpy_view(), self.test_image.numpy_view()
|
||||
)
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _FaceStylizerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
for timestamp in range(0, 300, 30):
|
||||
stylizer.stylize_async(self.test_image, timestamp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
|
@ -14,7 +14,7 @@
|
|||
"""MediaPipe face stylizer task."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable
|
||||
|
||||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
|
@ -31,6 +31,7 @@ _BaseOptions = base_options_module.BaseOptions
|
|||
_FaceStylizerGraphOptionsProto = (
|
||||
face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
|
||||
)
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
|
@ -54,11 +55,15 @@ class FaceStylizerOptions:
|
|||
"""
|
||||
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
|
||||
"""Generates an FaceStylizerOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = (
|
||||
False if self.running_mode == _RunningMode.IMAGE else True
|
||||
)
|
||||
return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)
|
||||
|
||||
|
||||
|
|
1
mediapipe/tasks/testdata/vision/BUILD
vendored
1
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -48,7 +48,6 @@ mediapipe_files(srcs = [
|
|||
"face_landmark.tflite",
|
||||
"face_landmarker.task",
|
||||
"face_landmarker_v2.task",
|
||||
"face_stylizer_color_ink.task",
|
||||
"fist.jpg",
|
||||
"fist.png",
|
||||
"gesture_recognizer.task",
|
||||
|
|
Loading…
Reference in New Issue
Block a user