Revised Face Stylizer tests

This commit is contained in:
Kinar 2023-09-13 22:54:32 +05:30
parent 1216c0f998
commit d21460e17b
3 changed files with 7 additions and 162 deletions

View File

@ -15,9 +15,7 @@
import enum import enum
import os import os
from unittest import mock
import numpy as np
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
@ -38,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL = 'face_stylizer.task' _MODEL = 'face_stylizer_color_ink.task'
_LARGE_FACE_IMAGE = "portrait.jpg" _LARGE_FACE_IMAGE = "portrait.jpg"
_MODEL_IMAGE_SIZE = 256 _MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
@ -181,163 +179,6 @@ class FaceStylizerTest(parameterized.TestCase):
stylized_image = stylizer.stylize(test_image, image_processing_options) stylized_image = stylizer.stylize(test_image, image_processing_options)
self.assertIsNone(stylized_image) self.assertIsNone(stylized_image)
def test_missing_result_callback(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _FaceStylizer.create_from_options(options) as unused_stylizer:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _FaceStylizer.create_from_options(options) as unused_stylizer:
pass
def test_calling_stylize_for_video_in_image_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
stylizer.stylize_for_video(self.test_image, 0)
def test_calling_stylize_async_in_image_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
stylizer.stylize_async(self.test_image, 0)
def test_calling_classify_in_video_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
stylizer.stylize(self.test_image)
def test_calling_classify_async_in_video_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
stylizer.stylize_async(self.test_image, 0)
def test_classify_for_video_with_out_of_order_timestamp(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
unused_result = stylizer.stylize_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
stylizer.stylize_for_video(self.test_image, 0)
def test_stylize_for_video(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceStylizer.create_from_options(options) as stylizer:
for timestamp in range(0, 300, 30):
stylized_image = stylizer.stylize_for_video(
self.test_image, timestamp
)
self.assertIsInstance(stylized_image, _Image)
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
def test_calling_stylize_in_live_stream_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
stylizer.stylize(self.test_image)
def test_calling_stylize_for_video_in_live_stream_mode(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceStylizer.create_from_options(options) as stylizer:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
stylizer.stylize_for_video(self.test_image, 0)
def test_stylize_async_calls_with_illegal_timestamp(self):
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceStylizer.create_from_options(options) as stylizer:
stylizer.stylize_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
stylizer.stylize_async(self.test_image, 0)
def test_stylize_async_calls(self):
observed_timestamp_ms = -1
def check_result(
stylized_image: _Image, output_image: _Image, timestamp_ms: int
):
self.assertIsInstance(stylized_image, _Image)
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
self.assertTrue(
np.array_equal(
output_image.numpy_view(), self.test_image.numpy_view()
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _FaceStylizerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result,
)
with _FaceStylizer.create_from_options(options) as stylizer:
for timestamp in range(0, 300, 30):
stylizer.stylize_async(self.test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -14,7 +14,7 @@
"""MediaPipe face stylizer task.""" """MediaPipe face stylizer task."""
import dataclasses import dataclasses
from typing import Optional from typing import Optional, Callable
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
@ -31,6 +31,7 @@ _BaseOptions = base_options_module.BaseOptions
_FaceStylizerGraphOptionsProto = ( _FaceStylizerGraphOptionsProto = (
face_stylizer_graph_options_pb2.FaceStylizerGraphOptions face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
) )
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -54,11 +55,15 @@ class FaceStylizerOptions:
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _FaceStylizerGraphOptionsProto: def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
"""Generates an FaceStylizerOptions protobuf object.""" """Generates an FaceStylizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
return _FaceStylizerGraphOptionsProto(base_options=base_options_proto) return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)

View File

@ -48,7 +48,6 @@ mediapipe_files(srcs = [
"face_landmark.tflite", "face_landmark.tflite",
"face_landmarker.task", "face_landmarker.task",
"face_landmarker_v2.task", "face_landmarker_v2.task",
"face_stylizer_color_ink.task",
"fist.jpg", "fist.jpg",
"fist.png", "fist.png",
"gesture_recognizer.task", "gesture_recognizer.task",