Revised Face Stylizer tests
This commit is contained in:
		
							parent
							
								
									1216c0f998
								
							
						
					
					
						commit
						d21460e17b
					
				| 
						 | 
				
			
			@ -15,9 +15,7 @@
 | 
			
		|||
 | 
			
		||||
import enum
 | 
			
		||||
import os
 | 
			
		||||
from unittest import mock
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from absl.testing import absltest
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -38,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
 | 
			
		|||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
 | 
			
		||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 | 
			
		||||
 | 
			
		||||
_MODEL = 'face_stylizer.task'
 | 
			
		||||
_MODEL = 'face_stylizer_color_ink.task'
 | 
			
		||||
_LARGE_FACE_IMAGE = "portrait.jpg"
 | 
			
		||||
_MODEL_IMAGE_SIZE = 256
 | 
			
		||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
 | 
			
		||||
| 
						 | 
				
			
			@ -181,163 +179,6 @@ class FaceStylizerTest(parameterized.TestCase):
 | 
			
		|||
      stylized_image = stylizer.stylize(test_image, image_processing_options)
 | 
			
		||||
      self.assertIsNone(stylized_image)
 | 
			
		||||
 | 
			
		||||
  def test_missing_result_callback(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
    )
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        ValueError, r'result callback must be provided'
 | 
			
		||||
    ):
 | 
			
		||||
      with _FaceStylizer.create_from_options(options) as unused_stylizer:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
 | 
			
		||||
  def test_illegal_result_callback(self, running_mode):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=running_mode,
 | 
			
		||||
      result_callback=mock.MagicMock(),
 | 
			
		||||
    )
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        ValueError, r'result callback should not be provided'
 | 
			
		||||
    ):
 | 
			
		||||
      with _FaceStylizer.create_from_options(options) as unused_stylizer:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
  def test_calling_stylize_for_video_in_image_mode(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.IMAGE,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'not initialized with the video mode'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_calling_stylize_async_in_image_mode(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.IMAGE,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'not initialized with the live stream mode'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_in_video_mode(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'not initialized with the image mode'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize(self.test_image)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_async_in_video_mode(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'not initialized with the live stream mode'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_classify_for_video_with_out_of_order_timestamp(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      unused_result = stylizer.stylize_for_video(self.test_image, 1)
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'Input timestamp must be monotonically increasing'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_stylize_for_video(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      for timestamp in range(0, 300, 30):
 | 
			
		||||
        stylized_image = stylizer.stylize_for_video(
 | 
			
		||||
          self.test_image, timestamp
 | 
			
		||||
        )
 | 
			
		||||
        self.assertIsInstance(stylized_image, _Image)
 | 
			
		||||
        self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
 | 
			
		||||
        self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
 | 
			
		||||
 | 
			
		||||
  def test_calling_stylize_in_live_stream_mode(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      result_callback=mock.MagicMock(),
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'not initialized with the image mode'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize(self.test_image)
 | 
			
		||||
 | 
			
		||||
  def test_calling_stylize_for_video_in_live_stream_mode(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      result_callback=mock.MagicMock(),
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'not initialized with the video mode'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_stylize_async_calls_with_illegal_timestamp(self):
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      result_callback=mock.MagicMock(),
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      stylizer.stylize_async(self.test_image, 100)
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'Input timestamp must be monotonically increasing'
 | 
			
		||||
      ):
 | 
			
		||||
        stylizer.stylize_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_stylize_async_calls(self):
 | 
			
		||||
    observed_timestamp_ms = -1
 | 
			
		||||
 | 
			
		||||
    def check_result(
 | 
			
		||||
        stylized_image: _Image, output_image: _Image, timestamp_ms: int
 | 
			
		||||
    ):
 | 
			
		||||
      self.assertIsInstance(stylized_image, _Image)
 | 
			
		||||
      self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
 | 
			
		||||
      self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
 | 
			
		||||
      self.assertTrue(
 | 
			
		||||
        np.array_equal(
 | 
			
		||||
          output_image.numpy_view(), self.test_image.numpy_view()
 | 
			
		||||
        )
 | 
			
		||||
      )
 | 
			
		||||
      self.assertLess(observed_timestamp_ms, timestamp_ms)
 | 
			
		||||
      self.observed_timestamp_ms = timestamp_ms
 | 
			
		||||
 | 
			
		||||
    options = _FaceStylizerOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      result_callback=check_result,
 | 
			
		||||
    )
 | 
			
		||||
    with _FaceStylizer.create_from_options(options) as stylizer:
 | 
			
		||||
      for timestamp in range(0, 300, 30):
 | 
			
		||||
        stylizer.stylize_async(self.test_image, timestamp)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  absltest.main()
 | 
			
		||||
| 
						 | 
				
			
			@ -14,7 +14,7 @@
 | 
			
		|||
"""MediaPipe face stylizer task."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from typing import Optional, Callable
 | 
			
		||||
 | 
			
		||||
from mediapipe.python import packet_creator
 | 
			
		||||
from mediapipe.python import packet_getter
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +31,7 @@ _BaseOptions = base_options_module.BaseOptions
 | 
			
		|||
_FaceStylizerGraphOptionsProto = (
 | 
			
		||||
    face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
 | 
			
		||||
)
 | 
			
		||||
_RunningMode = running_mode_module.VisionTaskRunningMode
 | 
			
		||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 | 
			
		||||
_TaskInfo = task_info_module.TaskInfo
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -54,11 +55,15 @@ class FaceStylizerOptions:
 | 
			
		|||
  """
 | 
			
		||||
 | 
			
		||||
  base_options: _BaseOptions
 | 
			
		||||
  running_mode: _RunningMode = _RunningMode.IMAGE
 | 
			
		||||
 | 
			
		||||
  @doc_controls.do_not_generate_docs
 | 
			
		||||
  def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
 | 
			
		||||
    """Generates an FaceStylizerOptions protobuf object."""
 | 
			
		||||
    base_options_proto = self.base_options.to_pb2()
 | 
			
		||||
    base_options_proto.use_stream_mode = (
 | 
			
		||||
        False if self.running_mode == _RunningMode.IMAGE else True
 | 
			
		||||
    )
 | 
			
		||||
    return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										1
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -48,7 +48,6 @@ mediapipe_files(srcs = [
 | 
			
		|||
    "face_landmark.tflite",
 | 
			
		||||
    "face_landmarker.task",
 | 
			
		||||
    "face_landmarker_v2.task",
 | 
			
		||||
    "face_stylizer_color_ink.task",
 | 
			
		||||
    "fist.jpg",
 | 
			
		||||
    "fist.png",
 | 
			
		||||
    "gesture_recognizer.task",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user