Added the ClassifyForVideo API
This commit is contained in:
parent
ef31e3a482
commit
e250c903f5
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
@ -121,6 +122,18 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# a context.
|
# a context.
|
||||||
classifier.close()
|
classifier.close()
|
||||||
|
|
||||||
|
def test_classify_for_video(self):
|
||||||
|
classifier_options = _ClassifierOptions(max_results=4)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
classifier_options=classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
classification_result = classifier.classify_for_video(
|
||||||
|
self.test_image, timestamp)
|
||||||
|
self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
"""MediaPipe image classifier task."""
|
"""MediaPipe image classifier task."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Callable, List, Mapping, Optional
|
from typing import Callable, Mapping, Optional
|
||||||
|
|
||||||
from mediapipe.python import packet_creator
|
from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
|
@ -40,6 +40,7 @@ _TaskRunner = task_runner_module.TaskRunner
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
_IMAGE_TAG = 'IMAGE'
|
_IMAGE_TAG = 'IMAGE'
|
||||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
|
||||||
|
|
||||||
|
@ -66,8 +67,8 @@ class ImageClassifierOptions:
|
||||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
||||||
result_callback: Optional[
|
result_callback: Optional[
|
||||||
Callable[[classifications_module.ClassificationResult],
|
Callable[[classifications_module.ClassificationResult, image_module.Image,
|
||||||
None]] = None
|
int], None]] = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
||||||
|
@ -134,7 +135,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
classifications_module.Classifications.create_from_pb2(classification)
|
classifications_module.Classifications.create_from_pb2(classification)
|
||||||
for classification in classification_result_proto.classifications
|
for classification in classification_result_proto.classifications
|
||||||
])
|
])
|
||||||
options.result_callback(classification_result)
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
|
options.result_callback(classification_result, image, timestamp)
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
@ -143,7 +146,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
],
|
],
|
||||||
output_streams=[
|
output_streams=[
|
||||||
':'.join([_CLASSIFICATION_RESULT_TAG,
|
':'.join([_CLASSIFICATION_RESULT_TAG,
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME]),
|
||||||
|
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||||
],
|
],
|
||||||
task_options=options)
|
task_options=options)
|
||||||
return cls(
|
return cls(
|
||||||
|
@ -178,3 +182,37 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
classifications_module.Classifications.create_from_pb2(classification)
|
classifications_module.Classifications.create_from_pb2(classification)
|
||||||
for classification in classification_result_proto.classifications
|
for classification in classification_result_proto.classifications
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def classify_for_video(
|
||||||
|
self, image: image_module.Image,
|
||||||
|
timestamp_ms: int
|
||||||
|
) -> classifications_module.ClassificationResult:
|
||||||
|
"""Performs image classification on the provided video frames.
|
||||||
|
|
||||||
|
Only use this method when the ImageClassifier is created with the video
|
||||||
|
running mode. It's required to provide the video frame's timestamp (in
|
||||||
|
milliseconds) along with the video frame. The input timestamps should be
|
||||||
|
monotonically increasing for adjacent calls of this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: MediaPipe Image.
|
||||||
|
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A classification result object that contains a list of classifications.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If image classification failed to run.
|
||||||
|
"""
|
||||||
|
output_packets = self._process_video_data({
|
||||||
|
_IMAGE_IN_STREAM_NAME:
|
||||||
|
packet_creator.create_image(image).at(timestamp_ms)
|
||||||
|
})
|
||||||
|
classification_result_proto = packet_getter.get_proto(
|
||||||
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
||||||
|
|
||||||
|
return classifications_module.ClassificationResult([
|
||||||
|
classifications_module.Classifications.create_from_pb2(classification)
|
||||||
|
for classification in classification_result_proto.classifications
|
||||||
|
])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user