diff --git a/mediapipe/tasks/python/test/vision/face_stylizer_test.py b/mediapipe/tasks/python/test/vision/face_stylizer_test.py index 27bcc531e..4f5392128 100644 --- a/mediapipe/tasks/python/test/vision/face_stylizer_test.py +++ b/mediapipe/tasks/python/test/vision/face_stylizer_test.py @@ -15,9 +15,7 @@ import enum import os -from unittest import mock -import numpy as np from absl.testing import absltest from absl.testing import parameterized @@ -38,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_MODEL = 'face_stylizer.task' +_MODEL = 'face_stylizer_color_ink.task' _LARGE_FACE_IMAGE = "portrait.jpg" _MODEL_IMAGE_SIZE = 256 _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' @@ -181,163 +179,6 @@ class FaceStylizerTest(parameterized.TestCase): stylized_image = stylizer.stylize(test_image, image_processing_options) self.assertIsNone(stylized_image) - def test_missing_result_callback(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - ) - with self.assertRaisesRegex( - ValueError, r'result callback must be provided' - ): - with _FaceStylizer.create_from_options(options) as unused_stylizer: - pass - - @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) - def test_illegal_result_callback(self, running_mode): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=running_mode, - result_callback=mock.MagicMock(), - ) - with self.assertRaisesRegex( - ValueError, r'result callback should not be provided' - ): - with _FaceStylizer.create_from_options(options) as unused_stylizer: - pass - - def test_calling_stylize_for_video_in_image_mode(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - with self.assertRaisesRegex( - ValueError, r'not initialized with the video mode' - ): - stylizer.stylize_for_video(self.test_image, 0) - - def test_calling_stylize_async_in_image_mode(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - with self.assertRaisesRegex( - ValueError, r'not initialized with the live stream mode' - ): - stylizer.stylize_async(self.test_image, 0) - - def test_calling_classify_in_video_mode(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - with self.assertRaisesRegex( - ValueError, r'not initialized with the image mode' - ): - stylizer.stylize(self.test_image) - - def test_calling_classify_async_in_video_mode(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - with self.assertRaisesRegex( - ValueError, r'not initialized with the live stream mode' - ): - stylizer.stylize_async(self.test_image, 0) - - def test_classify_for_video_with_out_of_order_timestamp(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - unused_result = stylizer.stylize_for_video(self.test_image, 1) - with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing' - ): - stylizer.stylize_for_video(self.test_image, 0) - - def test_stylize_for_video(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - for timestamp in range(0, 300, 30): - stylized_image = stylizer.stylize_for_video( - self.test_image, timestamp - ) - self.assertIsInstance(stylized_image, _Image) - self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE) - self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE) - - def test_calling_stylize_in_live_stream_mode(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock(), - ) - with _FaceStylizer.create_from_options(options) as stylizer: - with self.assertRaisesRegex( - ValueError, r'not initialized with the image mode' - ): - stylizer.stylize(self.test_image) - - def test_calling_stylize_for_video_in_live_stream_mode(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock(), - ) - with _FaceStylizer.create_from_options(options) as stylizer: - with self.assertRaisesRegex( - ValueError, r'not initialized with the video mode' - ): - stylizer.stylize_for_video(self.test_image, 0) - - def test_stylize_async_calls_with_illegal_timestamp(self): - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock(), - ) - with _FaceStylizer.create_from_options(options) as stylizer: - stylizer.stylize_async(self.test_image, 100) - with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing' - ): - stylizer.stylize_async(self.test_image, 0) - - def test_stylize_async_calls(self): - observed_timestamp_ms = -1 - - def check_result( - stylized_image: _Image, output_image: _Image, timestamp_ms: int - ): - self.assertIsInstance(stylized_image, _Image) - self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE) - self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE) - self.assertTrue( - np.array_equal( - output_image.numpy_view(), self.test_image.numpy_view() - ) - ) - self.assertLess(observed_timestamp_ms, timestamp_ms) - self.observed_timestamp_ms = timestamp_ms - - options = _FaceStylizerOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result, - ) - with _FaceStylizer.create_from_options(options) as stylizer: - for timestamp in range(0, 300, 30): - stylizer.stylize_async(self.test_image, timestamp) - if __name__ == '__main__': absltest.main() \ No newline at end of file diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py index af70a61dc..aa8562e80 100644 --- a/mediapipe/tasks/python/vision/face_stylizer.py +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -14,7 +14,7 @@ """MediaPipe face stylizer task.""" import dataclasses -from typing import Optional +from typing import Optional, Callable from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -31,6 +31,7 @@ _BaseOptions = base_options_module.BaseOptions _FaceStylizerGraphOptionsProto = ( face_stylizer_graph_options_pb2.FaceStylizerGraphOptions ) +_RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -54,11 +55,15 @@ class FaceStylizerOptions: """ base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE @doc_controls.do_not_generate_docs def to_pb2(self) -> _FaceStylizerGraphOptionsProto: """Generates an FaceStylizerOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) return _FaceStylizerGraphOptionsProto(base_options=base_options_proto) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index f53d45bda..4fec8df9f 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -48,7 +48,6 @@ mediapipe_files(srcs = [ "face_landmark.tflite", "face_landmarker.task", "face_landmarker_v2.task", - "face_stylizer_color_ink.task", "fist.jpg", "fist.png", "gesture_recognizer.task",