Added a check to see if the output packet is empty in the API and updated tests

This commit is contained in:
kinaryml 2022-10-11 22:34:56 -07:00
parent 0a8dbc7576
commit 8ea0018397
2 changed files with 22 additions and 20 deletions

View File

@ -44,24 +44,27 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg'
_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[],
timestamp_ms=0
)
],
head_index=0,
head_name='probability')
])
_ALLOW_LIST = ['cheeseburger', 'guacamole']
_DENY_LIST = ['cheeseburger']
_SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[],
timestamp_ms=timestamp_ms
)
],
head_index=0,
head_name='probability')
])
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(
classifications=[
@ -447,16 +450,14 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self.test_image, 0)
# TODO: Fix the packet is empty issue.
"""
@parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
(1, _EMPTY_CLASSIFICATION_RESULT))
def test_classify_async_calls(self, threshold, expected_result):
@parameterized.parameters((0, _generate_burger_results),
(1, _generate_empty_results))
def test_classify_async_calls(self, threshold, expected_result_fn):
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
self.assertEqual(result, expected_result)
self.assertEqual(result, expected_result_fn(timestamp_ms))
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
@ -474,7 +475,6 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classifier.classify_async(self.test_image, timestamp)
classifier.close()
"""
if __name__ == '__main__':

View File

@ -138,6 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
"""
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
classification_result_proto = packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
@ -148,7 +150,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image,
timestamp.value / _MICRO_SECONDS_PER_MILLISECOND)
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,