Renamed HandLandmarksDetectionResult to HandLandmarkerResult

This commit is contained in:
kinaryml 2022-11-08 23:06:28 -08:00
parent 96e2eb38c7
commit 46f135e54d
2 changed files with 36 additions and 36 deletions

View File

@ -41,7 +41,7 @@ _LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionR
_Image = image_module.Image _Image = image_module.Image
_HandLandmarker = hand_landmarker.HandLandmarker _HandLandmarker = hand_landmarker.HandLandmarker
_HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions _HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
_HandLandmarksDetectionResult = hand_landmarker.HandLandmarksDetectionResult _HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -58,8 +58,8 @@ _LANDMARKS_ERROR_TOLERANCE = 0.03
_HANDEDNESS_MARGIN = 0.05 _HANDEDNESS_MARGIN = 0.05
def _get_expected_hand_landmarks_detection_result( def _get_expected_hand_landmarker_result(
file_path: str) -> _HandLandmarksDetectionResult: file_path: str) -> _HandLandmarkerResult:
landmarks_detection_result_file_path = test_utils.get_test_data_path( landmarks_detection_result_file_path = test_utils.get_test_data_path(
file_path) file_path)
with open(landmarks_detection_result_file_path, "rb") as f: with open(landmarks_detection_result_file_path, "rb") as f:
@ -69,7 +69,7 @@ def _get_expected_hand_landmarks_detection_result(
text_format.Parse(f.read(), landmarks_detection_result_proto) text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
landmarks_detection_result_proto) landmarks_detection_result_proto)
return _HandLandmarksDetectionResult( return _HandLandmarkerResult(
handedness=[landmarks_detection_result.categories], handedness=[landmarks_detection_result.categories],
hand_landmarks=[landmarks_detection_result.landmarks], hand_landmarks=[landmarks_detection_result.landmarks],
hand_world_landmarks=[landmarks_detection_result.world_landmarks]) hand_world_landmarks=[landmarks_detection_result.world_landmarks])
@ -91,8 +91,8 @@ class HandLandmarkerTest(parameterized.TestCase):
def _assert_actual_result_approximately_matches_expected_result( def _assert_actual_result_approximately_matches_expected_result(
self, self,
actual_result: _HandLandmarksDetectionResult, actual_result: _HandLandmarkerResult,
expected_result: _HandLandmarksDetectionResult expected_result: _HandLandmarkerResult
): ):
# Expects to have the same number of hands detected. # Expects to have the same number of hands detected.
self.assertLen(actual_result.hand_landmarks, self.assertLen(actual_result.hand_landmarks,
@ -150,10 +150,10 @@ class HandLandmarkerTest(parameterized.TestCase):
self.assertIsInstance(landmarker, _HandLandmarker) self.assertIsInstance(landmarker, _HandLandmarker)
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, _get_expected_hand_landmarks_detection_result( (ModelFileType.FILE_NAME, _get_expected_hand_landmarker_result(
_THUMB_UP_LANDMARKS _THUMB_UP_LANDMARKS
)), )),
(ModelFileType.FILE_CONTENT, _get_expected_hand_landmarks_detection_result( (ModelFileType.FILE_CONTENT, _get_expected_hand_landmarker_result(
_THUMB_UP_LANDMARKS _THUMB_UP_LANDMARKS
))) )))
def test_detect(self, model_file_type, expected_detection_result): def test_detect(self, model_file_type, expected_detection_result):
@ -181,10 +181,10 @@ class HandLandmarkerTest(parameterized.TestCase):
landmarker.close() landmarker.close()
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, _get_expected_hand_landmarks_detection_result( (ModelFileType.FILE_NAME, _get_expected_hand_landmarker_result(
_THUMB_UP_LANDMARKS _THUMB_UP_LANDMARKS
)), )),
(ModelFileType.FILE_CONTENT, _get_expected_hand_landmarks_detection_result( (ModelFileType.FILE_CONTENT, _get_expected_hand_landmarker_result(
_THUMB_UP_LANDMARKS _THUMB_UP_LANDMARKS
))) )))
def test_detect_in_context(self, model_file_type, expected_detection_result): def test_detect_in_context(self, model_file_type, expected_detection_result):
@ -233,7 +233,7 @@ class HandLandmarkerTest(parameterized.TestCase):
# Performs hand landmarks detection on the input. # Performs hand landmarks detection on the input.
detection_result = landmarker.detect(test_image, detection_result = landmarker.detect(test_image,
image_processing_options) image_processing_options)
expected_detection_result = _get_expected_hand_landmarks_detection_result( expected_detection_result = _get_expected_hand_landmarker_result(
_POINTING_UP_ROTATED_LANDMARKS) _POINTING_UP_ROTATED_LANDMARKS)
# Comparing results. # Comparing results.
self._assert_actual_result_approximately_matches_expected_result( self._assert_actual_result_approximately_matches_expected_result(
@ -332,14 +332,14 @@ class HandLandmarkerTest(parameterized.TestCase):
landmarker.detect_for_video(self.test_image, 0) landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters( @parameterized.parameters(
(_THUMB_UP_IMAGE, 0, _get_expected_hand_landmarks_detection_result( (_THUMB_UP_IMAGE, 0, _get_expected_hand_landmarker_result(
_THUMB_UP_LANDMARKS)), _THUMB_UP_LANDMARKS)),
(_POINTING_UP_IMAGE, 0, _get_expected_hand_landmarks_detection_result( (_POINTING_UP_IMAGE, 0, _get_expected_hand_landmarker_result(
_POINTING_UP_LANDMARKS)), _POINTING_UP_LANDMARKS)),
(_POINTING_UP_ROTATED_IMAGE, -90, (_POINTING_UP_ROTATED_IMAGE, -90,
_get_expected_hand_landmarks_detection_result( _get_expected_hand_landmarker_result(
_POINTING_UP_ROTATED_LANDMARKS)), _POINTING_UP_ROTATED_LANDMARKS)),
(_NO_HANDS_IMAGE, 0, _HandLandmarksDetectionResult([], [], []))) (_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
def test_detect_for_video(self, image_path, rotation, expected_result): def test_detect_for_video(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path)) test_utils.get_test_data_path(image_path))
@ -392,14 +392,14 @@ class HandLandmarkerTest(parameterized.TestCase):
landmarker.detect_async(self.test_image, 0) landmarker.detect_async(self.test_image, 0)
@parameterized.parameters( @parameterized.parameters(
(_THUMB_UP_IMAGE, 0, _get_expected_hand_landmarks_detection_result( (_THUMB_UP_IMAGE, 0, _get_expected_hand_landmarker_result(
_THUMB_UP_LANDMARKS)), _THUMB_UP_LANDMARKS)),
(_POINTING_UP_IMAGE, 0, _get_expected_hand_landmarks_detection_result( (_POINTING_UP_IMAGE, 0, _get_expected_hand_landmarker_result(
_POINTING_UP_LANDMARKS)), _POINTING_UP_LANDMARKS)),
(_POINTING_UP_ROTATED_IMAGE, -90, (_POINTING_UP_ROTATED_IMAGE, -90,
_get_expected_hand_landmarks_detection_result( _get_expected_hand_landmarker_result(
_POINTING_UP_ROTATED_LANDMARKS)), _POINTING_UP_ROTATED_LANDMARKS)),
(_NO_HANDS_IMAGE, 0, _HandLandmarksDetectionResult([], [], []))) (_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
def test_detect_async_calls(self, image_path, rotation, expected_result): def test_detect_async_calls(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path)) test_utils.get_test_data_path(image_path))
@ -407,7 +407,7 @@ class HandLandmarkerTest(parameterized.TestCase):
image_processing_options = _ImageProcessingOptions(rotation_degrees=rotation) image_processing_options = _ImageProcessingOptions(rotation_degrees=rotation)
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _HandLandmarksDetectionResult, def check_result(result: _HandLandmarkerResult,
output_image: _Image, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
if result.hand_landmarks and result.hand_world_landmarks and \ if result.hand_landmarks and result.hand_world_landmarks and \

View File

@ -54,8 +54,8 @@ _MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass @dataclasses.dataclass
class HandLandmarksDetectionResult: class HandLandmarkerResult:
"""The hand landmarks detection result from HandLandmarker, where each vector """The hand landmarks result from HandLandmarker, where each vector
element represents a single hand detected in the image. element represents a single hand detected in the image.
Attributes: Attributes:
@ -69,9 +69,9 @@ class HandLandmarksDetectionResult:
hand_world_landmarks: List[List[landmark_module.Landmark]] hand_world_landmarks: List[List[landmark_module.Landmark]]
def _build_detection_result( def _build_landmarker_result(
output_packets: Mapping[str,packet_module.Packet] output_packets: Mapping[str,packet_module.Packet]
) -> HandLandmarksDetectionResult: ) -> HandLandmarkerResult:
"""Constructs a `HandLandmarksDetectionResult` from output packets.""" """Constructs a `HandLandmarksDetectionResult` from output packets."""
handedness_proto_list = packet_getter.get_proto_list( handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME]) output_packets[_HANDEDNESS_STREAM_NAME])
@ -114,7 +114,7 @@ def _build_detection_result(
landmark_module.Landmark.create_from_pb2(hand_world_landmark)) landmark_module.Landmark.create_from_pb2(hand_world_landmark))
hand_world_landmarks_results.append(hand_world_landmarks_list) hand_world_landmarks_results.append(hand_world_landmarks_list)
return HandLandmarksDetectionResult(handedness_results, return HandLandmarkerResult(handedness_results,
hand_landmarks_results, hand_landmarks_results,
hand_world_landmarks_results) hand_world_landmarks_results)
@ -151,7 +151,7 @@ class HandLandmarkerOptions:
min_hand_presence_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5
result_callback: Optional[Callable[ result_callback: Optional[Callable[
[HandLandmarksDetectionResult, image_module.Image, int], None]] = None [HandLandmarkerResult, image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _HandLandmarkerGraphOptionsProto: def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
@ -221,11 +221,11 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME] empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
options.result_callback( options.result_callback(
HandLandmarksDetectionResult([], [], []), image, HandLandmarkerResult([], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
return return
hand_landmarks_detection_result = _build_detection_result(output_packets) hand_landmarks_detection_result = _build_landmarker_result(output_packets)
timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
options.result_callback(hand_landmarks_detection_result, image, options.result_callback(hand_landmarks_detection_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
@ -255,7 +255,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> HandLandmarksDetectionResult: ) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the given image. """Performs hand landmarks detection on the given image.
Only use this method when the HandLandmarker is created with the image Only use this method when the HandLandmarker is created with the image
@ -286,16 +286,16 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
}) })
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
return HandLandmarksDetectionResult([], [], []) return HandLandmarkerResult([], [], [])
return _build_detection_result(output_packets) return _build_landmarker_result(output_packets)
def detect_for_video( def detect_for_video(
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> HandLandmarksDetectionResult: ) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the provided video frame. """Performs hand landmarks detection on the provided video frame.
Only use this method when the HandLandmarker is created with the video Only use this method when the HandLandmarker is created with the video
@ -330,9 +330,9 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
}) })
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
return HandLandmarksDetectionResult([], [], []) return HandLandmarkerResult([], [], [])
return _build_detection_result(output_packets) return _build_landmarker_result(output_packets)
def detect_async( def detect_async(
self, self,