Added the ClassifyForVideo API
parent ef31e3a482
commit e250c903f5

@@ -15,6 +15,7 @@
 import enum
 
+import numpy as np
 from absl.testing import absltest
 from absl.testing import parameterized
 
 
@@ -121,6 +122,18 @@ class ImageClassifierTest(parameterized.TestCase):
     # a context.
     classifier.close()
 
+  def test_classify_for_video(self):
+    classifier_options = _ClassifierOptions(max_results=4)
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        running_mode=_RUNNING_MODE.VIDEO,
+        classifier_options=classifier_options)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      for timestamp in range(0, 300, 30):
+        classification_result = classifier.classify_for_video(
+            self.test_image, timestamp)
+        self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT)
+
 
 if __name__ == '__main__':
   absltest.main()
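
The new test drives the VIDEO running mode with synthetic timestamps stepped in fixed 30 ms increments. For real footage the millisecond timestamp is typically derived from the frame index and frame rate; a minimal sketch of that conversion (the frame rate and count are illustrative, not from this commit):

    # Millisecond timestamps for classify_for_video must increase monotonically.
    fps = 30.0
    timestamps_ms = [int(i * 1000 / fps) for i in range(10)]
    # -> [0, 33, 66, 100, 133, 166, 200, 233, 266, 300]

The remaining hunks below are against the image classifier task module itself ("""MediaPipe image classifier task.""").
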
@@ -14,7 +14,7 @@
 """MediaPipe image classifier task."""
 
 import dataclasses
-from typing import Callable, List, Mapping, Optional
+from typing import Callable, Mapping, Optional
 
 from mediapipe.python import packet_creator
 from mediapipe.python import packet_getter
@@ -40,6 +40,7 @@ _TaskRunner = task_runner_module.TaskRunner
 _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
 _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
 _IMAGE_IN_STREAM_NAME = 'image_in'
+_IMAGE_OUT_STREAM_NAME = 'image_out'
 _IMAGE_TAG = 'IMAGE'
 _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 
@@ -66,8 +67,8 @@ class ImageClassifierOptions:
   running_mode: _RunningMode = _RunningMode.IMAGE
   classifier_options: _ClassifierOptions = _ClassifierOptions()
   result_callback: Optional[
-      Callable[[classifications_module.ClassificationResult],
-               None]] = None
+      Callable[[classifications_module.ClassificationResult, image_module.Image,
+                int], None]] = None
 
   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
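
With this change, the result callback (used with the live-stream running mode) receives the output image and its timestamp in addition to the classification result. A hedged sketch of a conforming callback, reusing the module's classifications_module and image_module aliases; the function and parameter names are illustrative, not part of this commit:

    def on_classification_result(
        result: classifications_module.ClassificationResult,
        output_image: image_module.Image, timestamp: int):
      # Called asynchronously with the result, the image it was computed on,
      # and the timestamp of that image packet.
      del output_image, timestamp  # Unused in this sketch.
      print(result)
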
@@ -134,7 +135,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
           classifications_module.Classifications.create_from_pb2(classification)
           for classification in classification_result_proto.classifications
       ])
-      options.result_callback(classification_result)
+      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
+      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
+      options.result_callback(classification_result, image, timestamp)
 
     task_info = _TaskInfo(
         task_graph=_TASK_GRAPH_NAME,
@@ -143,7 +146,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         ],
         output_streams=[
             ':'.join([_CLASSIFICATION_RESULT_TAG,
-                      _CLASSIFICATION_RESULT_OUT_STREAM_NAME])
+                      _CLASSIFICATION_RESULT_OUT_STREAM_NAME]),
+            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
         ],
         task_options=options)
     return cls(
@@ -175,6 +179,40 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
 
     return classifications_module.ClassificationResult([
         classifications_module.Classifications.create_from_pb2(classification)
         for classification in classification_result_proto.classifications
     ])
+
+  def classify_for_video(
+      self, image: image_module.Image,
+      timestamp_ms: int
+  ) -> classifications_module.ClassificationResult:
+    """Performs image classification on the provided video frames.
+
+    Only use this method when the ImageClassifier is created with the video
+    running mode. It's required to provide the video frame's timestamp (in
+    milliseconds) along with the video frame. The input timestamps should be
+    monotonically increasing for adjacent calls of this method.
+
+    Args:
+      image: MediaPipe Image.
+      timestamp_ms: The timestamp of the input video frame in milliseconds.
+
+    Returns:
+      A classification result object that contains a list of classifications.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid.
+      RuntimeError: If image classification failed to run.
+    """
+    output_packets = self._process_video_data({
+        _IMAGE_IN_STREAM_NAME:
+            packet_creator.create_image(image).at(timestamp_ms)
+    })
+    classification_result_proto = packet_getter.get_proto(
+        output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
+
+    return classifications_module.ClassificationResult([
+        classifications_module.Classifications.create_from_pb2(classification)
+        for classification in classification_result_proto.classifications
+    ])
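
The docstring above states the contract for the new method: VIDEO running mode only, millisecond timestamps, monotonically increasing across calls. A usage sketch under those constraints, written with the aliases from the test file earlier in this diff (_ImageClassifier, _ImageClassifierOptions, _BaseOptions, _RUNNING_MODE, _ClassifierOptions); the _Image.create_from_file helper, the file paths, and the 30 fps assumption are placeholders, not part of this commit:

    # Placeholder input: any decoded frame wrapped as a MediaPipe Image works.
    frame = _Image.create_from_file('/path/to/frame.jpg')
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path='/path/to/classifier.tflite'),
        running_mode=_RUNNING_MODE.VIDEO,
        classifier_options=_ClassifierOptions(max_results=3))
    with _ImageClassifier.create_from_options(options) as classifier:
      # One call per frame, with monotonically increasing millisecond timestamps.
      for frame_index in range(90):
        timestamp_ms = int(frame_index * 1000 / 30)  # e.g. a 30 fps clip
        result = classifier.classify_for_video(frame, timestamp_ms)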