Added a check to see if the output packet is empty in the API and updated tests
This commit is contained in:
parent
0a8dbc7576
commit
8ea0018397
|
@ -44,22 +44,25 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||||
|
|
||||||
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
|
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
|
||||||
_IMAGE_FILE = 'burger.jpg'
|
_IMAGE_FILE = 'burger.jpg'
|
||||||
_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
|
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||||
|
_DENY_LIST = ['cheeseburger']
|
||||||
|
_SCORE_THRESHOLD = 0.5
|
||||||
|
_MAX_RESULTS = 3
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
return _ClassificationResult(
|
||||||
classifications=[
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
entries=[
|
||||||
_ClassificationEntry(
|
_ClassificationEntry(
|
||||||
categories=[],
|
categories=[],
|
||||||
timestamp_ms=0
|
timestamp_ms=timestamp_ms
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
])
|
||||||
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
|
||||||
_DENY_LIST = ['cheeseburger']
|
|
||||||
_SCORE_THRESHOLD = 0.5
|
|
||||||
_MAX_RESULTS = 3
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
@ -447,16 +450,14 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
classifier.classify_async(self.test_image, 0)
|
classifier.classify_async(self.test_image, 0)
|
||||||
|
|
||||||
# TODO: Fix the packet is empty issue.
|
@parameterized.parameters((0, _generate_burger_results),
|
||||||
"""
|
(1, _generate_empty_results))
|
||||||
@parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
|
def test_classify_async_calls(self, threshold, expected_result_fn):
|
||||||
(1, _EMPTY_CLASSIFICATION_RESULT))
|
|
||||||
def test_classify_async_calls(self, threshold, expected_result):
|
|
||||||
observed_timestamp_ms = -1
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
self.assertEqual(result, expected_result)
|
self.assertEqual(result, expected_result_fn(timestamp_ms))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(output_image.numpy_view(),
|
np.array_equal(output_image.numpy_view(),
|
||||||
self.test_image.numpy_view()))
|
self.test_image.numpy_view()))
|
||||||
|
@ -474,7 +475,6 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classifier.classify_async(self.test_image, timestamp)
|
classifier.classify_async(self.test_image, timestamp)
|
||||||
classifier.close()
|
classifier.close()
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -138,6 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
||||||
|
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||||
|
return
|
||||||
classification_result_proto = packet_getter.get_proto(
|
classification_result_proto = packet_getter.get_proto(
|
||||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
||||||
|
|
||||||
|
@ -148,7 +150,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
options.result_callback(classification_result, image,
|
options.result_callback(classification_result, image,
|
||||||
timestamp.value / _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user