Removed an unnecessary check and updated tests to check if the masks are generated or not
This commit is contained in:
		
							parent
							
								
									9032bce577
								
							
						
					
					
						commit
						b511822815
					
				| 
						 | 
				
			
			@ -104,12 +104,19 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
      self,
 | 
			
		||||
      actual_result: PoseLandmarkerResult,
 | 
			
		||||
      expected_result: PoseLandmarkerResult,
 | 
			
		||||
      output_segmentation_masks: bool,
 | 
			
		||||
      margin: float
 | 
			
		||||
  ):
 | 
			
		||||
    self._expect_pose_landmarks_correct(
 | 
			
		||||
        actual_result.pose_landmarks, expected_result.pose_landmarks,
 | 
			
		||||
        margin
 | 
			
		||||
    )
 | 
			
		||||
    if output_segmentation_masks:
 | 
			
		||||
      self.assertIsInstance(actual_result.segmentation_masks, List)
 | 
			
		||||
      for i, mask in enumerate(actual_result.segmentation_masks):
 | 
			
		||||
        self.assertIsInstance(mask, _Image)
 | 
			
		||||
    else:
 | 
			
		||||
      self.assertIsNone(actual_result.segmentation_masks)
 | 
			
		||||
 | 
			
		||||
  def test_create_from_file_succeeds_with_valid_model_path(self):
 | 
			
		||||
    # Creates with default option and valid model file successfully.
 | 
			
		||||
| 
						 | 
				
			
			@ -141,12 +148,17 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
      self.assertIsInstance(landmarker, _PoseLandmarker)
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
      (ModelFileType.FILE_NAME,
 | 
			
		||||
      (ModelFileType.FILE_NAME, False,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT,
 | 
			
		||||
      (ModelFileType.FILE_CONTENT, False,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (ModelFileType.FILE_NAME, True,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT, True,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS))
 | 
			
		||||
  )
 | 
			
		||||
  def test_detect(self, model_file_type, expected_detection_result):
 | 
			
		||||
  def test_detect(self, model_file_type, output_segmentation_masks,
 | 
			
		||||
                  expected_detection_result):
 | 
			
		||||
    # Creates pose landmarker.
 | 
			
		||||
    if model_file_type is ModelFileType.FILE_NAME:
 | 
			
		||||
      base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
			
		||||
| 
						 | 
				
			
			@ -158,7 +170,10 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
      # Should never happen
 | 
			
		||||
      raise ValueError('model_file_type is invalid.')
 | 
			
		||||
 | 
			
		||||
    options = _PoseLandmarkerOptions(base_options=base_options)
 | 
			
		||||
    options = _PoseLandmarkerOptions(
 | 
			
		||||
        base_options=base_options,
 | 
			
		||||
        output_segmentation_masks=output_segmentation_masks
 | 
			
		||||
    )
 | 
			
		||||
    landmarker = _PoseLandmarker.create_from_options(options)
 | 
			
		||||
 | 
			
		||||
    # Performs pose landmarks detection on the input.
 | 
			
		||||
| 
						 | 
				
			
			@ -166,19 +181,27 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
 | 
			
		||||
    # Comparing results.
 | 
			
		||||
    self._expect_pose_landmarker_results_correct(
 | 
			
		||||
        detection_result, expected_detection_result, _LANDMARKS_MARGIN
 | 
			
		||||
        detection_result, expected_detection_result, output_segmentation_masks,
 | 
			
		||||
        _LANDMARKS_MARGIN
 | 
			
		||||
    )
 | 
			
		||||
    # Closes the pose landmarker explicitly when the pose landmarker is not used
 | 
			
		||||
    # in a context.
 | 
			
		||||
    landmarker.close()
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
      (ModelFileType.FILE_NAME,
 | 
			
		||||
      (ModelFileType.FILE_NAME, False,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT,
 | 
			
		||||
      (ModelFileType.FILE_CONTENT, False,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (ModelFileType.FILE_NAME, True,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT, True,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS))
 | 
			
		||||
  )
 | 
			
		||||
  def test_detect_in_context(self, model_file_type, expected_detection_result):
 | 
			
		||||
  def test_detect_in_context(
 | 
			
		||||
      self, model_file_type, output_segmentation_masks,
 | 
			
		||||
      expected_detection_result
 | 
			
		||||
  ):
 | 
			
		||||
    # Creates pose landmarker.
 | 
			
		||||
    if model_file_type is ModelFileType.FILE_NAME:
 | 
			
		||||
      base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
			
		||||
| 
						 | 
				
			
			@ -190,14 +213,18 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
      # Should never happen
 | 
			
		||||
      raise ValueError('model_file_type is invalid.')
 | 
			
		||||
 | 
			
		||||
    options = _PoseLandmarkerOptions(base_options=base_options)
 | 
			
		||||
    options = _PoseLandmarkerOptions(
 | 
			
		||||
        base_options=base_options,
 | 
			
		||||
        output_segmentation_masks=output_segmentation_masks
 | 
			
		||||
    )
 | 
			
		||||
    with _PoseLandmarker.create_from_options(options) as landmarker:
 | 
			
		||||
      # Performs pose landmarks detection on the input.
 | 
			
		||||
      detection_result = landmarker.detect(self.test_image)
 | 
			
		||||
 | 
			
		||||
      # Comparing results.
 | 
			
		||||
      self._expect_pose_landmarker_results_correct(
 | 
			
		||||
        detection_result, expected_detection_result, _LANDMARKS_MARGIN
 | 
			
		||||
          detection_result, expected_detection_result,
 | 
			
		||||
          output_segmentation_masks, _LANDMARKS_MARGIN
 | 
			
		||||
      )
 | 
			
		||||
 | 
			
		||||
  def test_detect_fails_with_region_of_interest(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -295,12 +322,15 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
        landmarker.detect_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
      (_POSE_IMAGE, 0,
 | 
			
		||||
      (_POSE_IMAGE, 0, False,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (_BURGER_IMAGE, 0,
 | 
			
		||||
      (_POSE_IMAGE, 0, True,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (_BURGER_IMAGE, 0, False,
 | 
			
		||||
       PoseLandmarkerResult([], [], []))
 | 
			
		||||
  )
 | 
			
		||||
  def test_detect_for_video(self, image_path, rotation, expected_result):
 | 
			
		||||
  def test_detect_for_video(self, image_path, rotation,
 | 
			
		||||
                            output_segmentation_masks, expected_result):
 | 
			
		||||
    test_image = _Image.create_from_file(
 | 
			
		||||
        test_utils.get_test_data_path(image_path))
 | 
			
		||||
    # Set rotation parameters using ImageProcessingOptions.
 | 
			
		||||
| 
						 | 
				
			
			@ -308,6 +338,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
        rotation_degrees=rotation)
 | 
			
		||||
    options = _PoseLandmarkerOptions(
 | 
			
		||||
        base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
        output_segmentation_masks=output_segmentation_masks,
 | 
			
		||||
        running_mode=_RUNNING_MODE.VIDEO)
 | 
			
		||||
    with _PoseLandmarker.create_from_options(options) as landmarker:
 | 
			
		||||
      for timestamp in range(0, 300, 30):
 | 
			
		||||
| 
						 | 
				
			
			@ -315,7 +346,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
                                             image_processing_options)
 | 
			
		||||
        if result.pose_landmarks:
 | 
			
		||||
          self._expect_pose_landmarker_results_correct(
 | 
			
		||||
              result, expected_result, _LANDMARKS_MARGIN
 | 
			
		||||
              result, expected_result, output_segmentation_masks,
 | 
			
		||||
              _LANDMARKS_MARGIN
 | 
			
		||||
          )
 | 
			
		||||
        else:
 | 
			
		||||
          self.assertEqual(result, expected_result)
 | 
			
		||||
| 
						 | 
				
			
			@ -352,12 +384,15 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
        landmarker.detect_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
      (_POSE_IMAGE, 0,
 | 
			
		||||
      (_POSE_IMAGE, 0, False,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (_BURGER_IMAGE, 0,
 | 
			
		||||
      (_POSE_IMAGE, 0, True,
 | 
			
		||||
       _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
 | 
			
		||||
      (_BURGER_IMAGE, 0, False,
 | 
			
		||||
       PoseLandmarkerResult([], [], []))
 | 
			
		||||
  )
 | 
			
		||||
  def test_detect_async_calls(self, image_path, rotation, expected_result):
 | 
			
		||||
  def test_detect_async_calls(self, image_path, rotation,
 | 
			
		||||
                              output_segmentation_masks, expected_result):
 | 
			
		||||
    test_image = _Image.create_from_file(
 | 
			
		||||
        test_utils.get_test_data_path(image_path))
 | 
			
		||||
    # Set rotation parameters using ImageProcessingOptions.
 | 
			
		||||
| 
						 | 
				
			
			@ -369,7 +404,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
                     timestamp_ms: int):
 | 
			
		||||
      if result.pose_landmarks:
 | 
			
		||||
        self._expect_pose_landmarker_results_correct(
 | 
			
		||||
            result, expected_result, _LANDMARKS_MARGIN
 | 
			
		||||
            result, expected_result, output_segmentation_masks,
 | 
			
		||||
            _LANDMARKS_MARGIN
 | 
			
		||||
        )
 | 
			
		||||
      else:
 | 
			
		||||
        self.assertEqual(result, expected_result)
 | 
			
		||||
| 
						 | 
				
			
			@ -380,6 +416,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
 | 
			
		||||
    options = _PoseLandmarkerOptions(
 | 
			
		||||
        base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
        output_segmentation_masks=output_segmentation_masks,
 | 
			
		||||
        running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
        result_callback=check_result)
 | 
			
		||||
    with _PoseLandmarker.create_from_options(options) as landmarker:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -263,7 +263,6 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
 | 
			
		|||
      )
 | 
			
		||||
 | 
			
		||||
    output_streams = [
 | 
			
		||||
        ':'.join([_SEGMENTATION_MASK_TAG, _SEGMENTATION_MASK_STREAM_NAME]),
 | 
			
		||||
        ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
 | 
			
		||||
        ':'.join([
 | 
			
		||||
          _POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user