Code cleanup
This commit is contained in:
parent
f2e42d16bd
commit
ec0c5f4341
|
@ -124,7 +124,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
|
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
|
||||||
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
|
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
|
||||||
def test_classify(self, model_file_type, max_results,
|
def test_classify(self, model_file_type, max_results,
|
||||||
expected_classification_result):
|
expected_classification_result):
|
||||||
# Creates classifier.
|
# Creates classifier.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(file_name=self.model_path)
|
base_options = _BaseOptions(file_name=self.model_path)
|
||||||
|
@ -152,7 +152,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
|
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
|
||||||
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
|
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
|
||||||
def test_classify_in_context(self, model_file_type, max_results,
|
def test_classify_in_context(self, model_file_type, max_results,
|
||||||
expected_classification_result):
|
expected_classification_result):
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(file_name=self.model_path)
|
base_options = _BaseOptions(file_name=self.model_path)
|
||||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
@ -275,27 +275,6 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
|
|
||||||
# (1, _ClassificationResult(classifications=[])))
|
|
||||||
# def test_classify_async_calls(self, threshold, expected_result):
|
|
||||||
# observed_timestamp_ms = -1
|
|
||||||
#
|
|
||||||
# def check_result(result: _ClassificationResult, timestamp_ms: int):
|
|
||||||
# self.assertEqual(result, expected_result)
|
|
||||||
# self.assertLess(observed_timestamp_ms, timestamp_ms)
|
|
||||||
# self.observed_timestamp_ms = timestamp_ms
|
|
||||||
#
|
|
||||||
# options = _ImageClassifierOptions(
|
|
||||||
# base_options=_BaseOptions(file_name=self.model_path),
|
|
||||||
# running_mode=_RUNNING_MODE.LIVE_STREAM,
|
|
||||||
# max_results=4,
|
|
||||||
# score_threshold=threshold,
|
|
||||||
# result_callback=check_result)
|
|
||||||
# classifier = _ImageClassifier.create_from_options(options)
|
|
||||||
# for timestamp in range(0, 300, 30):
|
|
||||||
# classifier.classify_async(self.test_image, timestamp)
|
|
||||||
# classifier.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -83,7 +83,8 @@ class ImageClassifierOptions:
|
||||||
category_allowlist: Optional[List[str]] = None
|
category_allowlist: Optional[List[str]] = None
|
||||||
category_denylist: Optional[List[str]] = None
|
category_denylist: Optional[List[str]] = None
|
||||||
result_callback: Optional[
|
result_callback: Optional[
|
||||||
Callable[[classifications_module.ClassificationResult], None]] = None
|
Callable[[classifications_module.ClassificationResult],
|
||||||
|
None]] = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _ImageClassifierOptionsProto:
|
def to_pb2(self) -> _ImageClassifierOptionsProto:
|
||||||
|
@ -96,7 +97,8 @@ class ImageClassifierOptions:
|
||||||
max_results=self.max_results,
|
max_results=self.max_results,
|
||||||
score_threshold=self.score_threshold,
|
score_threshold=self.score_threshold,
|
||||||
category_allowlist=self.category_allowlist,
|
category_allowlist=self.category_allowlist,
|
||||||
category_denylist=self.category_denylist)
|
category_denylist=self.category_denylist
|
||||||
|
)
|
||||||
|
|
||||||
return _ImageClassifierOptionsProto(
|
return _ImageClassifierOptionsProto(
|
||||||
base_options=base_options_proto,
|
base_options=base_options_proto,
|
||||||
|
@ -198,30 +200,3 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
classifications_module.Classifications.create_from_pb2(classification)
|
classifications_module.Classifications.create_from_pb2(classification)
|
||||||
for classification in classification_result_proto.classifications
|
for classification in classification_result_proto.classifications
|
||||||
])
|
])
|
||||||
|
|
||||||
def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None:
|
|
||||||
"""Sends live image data (an Image with a unique timestamp) to perform image
|
|
||||||
classification.
|
|
||||||
|
|
||||||
This method will return immediately after the input image is accepted. The
|
|
||||||
results will be available via the `result_callback` provided in the
|
|
||||||
`ImageClassifierOptions`. The `detect_async` method is designed to process
|
|
||||||
live stream data such as camera input. To lower the overall latency, image
|
|
||||||
classifier may drop the input images if needed. In other words, it's not
|
|
||||||
guaranteed to have output per input image. The `result_callback` provides:
|
|
||||||
- A classification result object that contains a list of classifications.
|
|
||||||
- The input image that the image classifier runs on.
|
|
||||||
- The input timestamp in milliseconds.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: MediaPipe Image.
|
|
||||||
timestamp_ms: The timestamp of the input image in milliseconds.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the current input timestamp is smaller than what the image
|
|
||||||
classifier has already processed.
|
|
||||||
"""
|
|
||||||
self._send_live_stream_data({
|
|
||||||
_IMAGE_IN_STREAM_NAME:
|
|
||||||
packet_creator.create_image(image).at(timestamp_ms)
|
|
||||||
})
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user