Added a check to see if the output packet is empty in the API and updated tests
This commit is contained in:
parent
0a8dbc7576
commit
8ea0018397
|
@ -44,22 +44,25 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
|||
|
||||
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
|
||||
_IMAGE_FILE = 'burger.jpg'
|
||||
_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
|
||||
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||
_DENY_LIST = ['cheeseburger']
|
||||
_SCORE_THRESHOLD = 0.5
|
||||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
return _ClassificationResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[],
|
||||
timestamp_ms=0
|
||||
timestamp_ms=timestamp_ms
|
||||
)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||
_DENY_LIST = ['cheeseburger']
|
||||
_SCORE_THRESHOLD = 0.5
|
||||
_MAX_RESULTS = 3
|
||||
|
||||
|
||||
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
|
@ -447,16 +450,14 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
# TODO: Fix the packet is empty issue.
|
||||
"""
|
||||
@parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
|
||||
(1, _EMPTY_CLASSIFICATION_RESULT))
|
||||
def test_classify_async_calls(self, threshold, expected_result):
|
||||
@parameterized.parameters((0, _generate_burger_results),
|
||||
(1, _generate_empty_results))
|
||||
def test_classify_async_calls(self, threshold, expected_result_fn):
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
self.assertEqual(result, expected_result)
|
||||
self.assertEqual(result, expected_result_fn(timestamp_ms))
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
|
@ -474,7 +475,6 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classifier.classify_async(self.test_image, timestamp)
|
||||
classifier.close()
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -138,6 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
"""
|
||||
|
||||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||
return
|
||||
classification_result_proto = packet_getter.get_proto(
|
||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
||||
|
||||
|
@ -148,7 +150,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||
options.result_callback(classification_result, image,
|
||||
timestamp.value / _MICRO_SECONDS_PER_MILLISECOND)
|
||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
|
||||
task_info = _TaskInfo(
|
||||
task_graph=_TASK_GRAPH_NAME,
|
||||
|
|
Loading…
Reference in New Issue
Block a user