Code cleanup

This commit is contained in:
kinaryml 2022-09-11 14:00:49 -07:00
parent f2e42d16bd
commit ec0c5f4341
2 changed files with 6 additions and 52 deletions

View File

@ -124,7 +124,7 @@ class ImageClassifierTest(parameterized.TestCase):
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
def test_classify(self, model_file_type, max_results,
expected_classification_result):
expected_classification_result):
# Creates classifier.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(file_name=self.model_path)
@ -152,7 +152,7 @@ class ImageClassifierTest(parameterized.TestCase):
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
def test_classify_in_context(self, model_file_type, max_results,
expected_classification_result):
expected_classification_result):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(file_name=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
@ -275,27 +275,6 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
# @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
# (1, _ClassificationResult(classifications=[])))
# def test_classify_async_calls(self, threshold, expected_result):
# observed_timestamp_ms = -1
#
# def check_result(result: _ClassificationResult, timestamp_ms: int):
# self.assertEqual(result, expected_result)
# self.assertLess(observed_timestamp_ms, timestamp_ms)
# self.observed_timestamp_ms = timestamp_ms
#
# options = _ImageClassifierOptions(
# base_options=_BaseOptions(file_name=self.model_path),
# running_mode=_RUNNING_MODE.LIVE_STREAM,
# max_results=4,
# score_threshold=threshold,
# result_callback=check_result)
# classifier = _ImageClassifier.create_from_options(options)
# for timestamp in range(0, 300, 30):
# classifier.classify_async(self.test_image, timestamp)
# classifier.close()
if __name__ == '__main__':
absltest.main()

View File

@ -83,7 +83,8 @@ class ImageClassifierOptions:
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[
Callable[[classifications_module.ClassificationResult], None]] = None
Callable[[classifications_module.ClassificationResult],
None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierOptionsProto:
@ -96,7 +97,8 @@ class ImageClassifierOptions:
max_results=self.max_results,
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist)
category_denylist=self.category_denylist
)
return _ImageClassifierOptionsProto(
base_options=base_options_proto,
@ -198,30 +200,3 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image
classification.
This method will return immediately after the input image is accepted. The
results will be available via the `result_callback` provided in the
`ImageClassifierOptions`. The `detect_async` method is designed to process
live stream data such as camera input. To lower the overall latency, image
classifier may drop the input images if needed. In other words, it's not
guaranteed to have output per input image. The `result_callback` provides:
- A classification result object that contains a list of classifications.
- The input image that the image classifier runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
Raises:
ValueError: If the current input timestamp is smaller than what the image
classifier has already processed.
"""
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(timestamp_ms)
})