Added a check to see if the output packet is empty in the API and updated tests

This commit is contained in:
kinaryml 2022-10-11 22:34:56 -07:00
parent 0a8dbc7576
commit 8ea0018397
2 changed files with 22 additions and 20 deletions

View File

@ -44,24 +44,27 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg' _IMAGE_FILE = 'burger.jpg'
_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[],
timestamp_ms=0
)
],
head_index=0,
head_name='probability')
])
_ALLOW_LIST = ['cheeseburger', 'guacamole'] _ALLOW_LIST = ['cheeseburger', 'guacamole']
_DENY_LIST = ['cheeseburger'] _DENY_LIST = ['cheeseburger']
_SCORE_THRESHOLD = 0.5 _SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3 _MAX_RESULTS = 3
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[],
timestamp_ms=timestamp_ms
)
],
head_index=0,
head_name='probability')
])
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult( return _ClassificationResult(
classifications=[ classifications=[
@ -447,16 +450,14 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self.test_image, 0) classifier.classify_async(self.test_image, 0)
# TODO: Fix the packet is empty issue. @parameterized.parameters((0, _generate_burger_results),
""" (1, _generate_empty_results))
@parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), def test_classify_async_calls(self, threshold, expected_result_fn):
(1, _EMPTY_CLASSIFICATION_RESULT))
def test_classify_async_calls(self, threshold, expected_result):
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result_fn(timestamp_ms))
self.assertTrue( self.assertTrue(
np.array_equal(output_image.numpy_view(), np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view())) self.test_image.numpy_view()))
@ -474,7 +475,6 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classifier.classify_async(self.test_image, timestamp) classifier.classify_async(self.test_image, timestamp)
classifier.close() classifier.close()
"""
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -138,6 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
""" """
def packets_callback(output_packets: Mapping[str, packet_module.Packet]): def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
classification_result_proto = packet_getter.get_proto( classification_result_proto = packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
@ -148,7 +150,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image, options.result_callback(classification_result, image,
timestamp.value / _MICRO_SECONDS_PER_MILLISECOND) timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,