Revised Face Stylizer tests

This commit is contained in:
Kinar 2023-09-13 22:54:32 +05:30
parent 1216c0f998
commit d21460e17b
3 changed files with 7 additions and 162 deletions

View File

@ -15,9 +15,7 @@
import enum
import os
from unittest import mock
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
@ -38,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL = 'face_stylizer.task'
_MODEL = 'face_stylizer_color_ink.task'
_LARGE_FACE_IMAGE = "portrait.jpg"
_MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
@ -181,163 +179,6 @@ class FaceStylizerTest(parameterized.TestCase):
stylized_image = stylizer.stylize(test_image, image_processing_options)
self.assertIsNone(stylized_image)
def test_missing_result_callback(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _FaceStylizer.create_from_options(options) as unused_stylizer:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _FaceStylizer.create_from_options(options) as unused_stylizer:
pass
def test_calling_stylize_for_video_in_image_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
stylizer.stylize_for_video(self.test_image, 0)
def test_calling_stylize_async_in_image_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
stylizer.stylize_async(self.test_image, 0)
def test_calling_classify_in_video_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
stylizer.stylize(self.test_image)
def test_calling_classify_async_in_video_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
stylizer.stylize_async(self.test_image, 0)
def test_classify_for_video_with_out_of_order_timestamp(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
unused_result = stylizer.stylize_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
stylizer.stylize_for_video(self.test_image, 0)
def test_stylize_for_video(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
for timestamp in range(0, 300, 30):
stylized_image = stylizer.stylize_for_video(
self.test_image, timestamp
)
self.assertIsInstance(stylized_image, _Image)
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
def test_calling_stylize_in_live_stream_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
stylizer.stylize(self.test_image)
def test_calling_stylize_for_video_in_live_stream_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
stylizer.stylize_for_video(self.test_image, 0)
def test_stylize_async_calls_with_illegal_timestamp(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceStylizer.create_from_options(options) as stylizer:
stylizer.stylize_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
stylizer.stylize_async(self.test_image, 0)
def test_stylize_async_calls(self):
observed_timestamp_ms = -1
def check_result(
stylized_image: _Image, output_image: _Image, timestamp_ms: int
):
self.assertIsInstance(stylized_image, _Image)
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
self.assertTrue(
np.array_equal(
output_image.numpy_view(), self.test_image.numpy_view()
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result,
)
with _FaceStylizer.create_from_options(options) as stylizer:
for timestamp in range(0, 300, 30):
stylizer.stylize_async(self.test_image, timestamp)
if __name__ == '__main__':
absltest.main()

View File

@ -14,7 +14,7 @@
"""MediaPipe face stylizer task."""
import dataclasses
from typing import Optional
from typing import Optional, Callable
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
@ -31,6 +31,7 @@ _BaseOptions = base_options_module.BaseOptions
_FaceStylizerGraphOptionsProto = (
face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
@ -54,11 +55,15 @@ class FaceStylizerOptions:
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
"""Generates an FaceStylizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)

View File

@ -48,7 +48,6 @@ mediapipe_files(srcs = [
"face_landmark.tflite",
"face_landmarker.task",
"face_landmarker_v2.task",
"face_stylizer_color_ink.task",
"fist.jpg",
"fist.png",
"gesture_recognizer.task",