Added the ClassifyForVideo API

This commit is contained in:
kinaryml 2022-10-05 05:24:52 -07:00
parent ef31e3a482
commit e250c903f5
2 changed files with 58 additions and 7 deletions

View File

@ -15,6 +15,7 @@
import enum
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
@ -121,6 +122,18 @@ class ImageClassifierTest(parameterized.TestCase):
# a context.
classifier.close()
def test_classify_for_video(self):
  """Classifies the test image as successive video frames in VIDEO mode.

  The same image is submitted at ten increasing timestamps (0, 30, ..., 270)
  and every returned result must equal the expected classification result.
  """
  max_four_results = _ClassifierOptions(max_results=4)
  video_options = _ImageClassifierOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO,
      classifier_options=max_four_results)
  frame_timestamps = range(0, 300, 30)
  with _ImageClassifier.create_from_options(video_options) as classifier:
    for frame_timestamp in frame_timestamps:
      frame_result = classifier.classify_for_video(self.test_image,
                                                   frame_timestamp)
      # Checked per frame so a failure on any single frame is caught.
      self.assertEqual(frame_result, _EXPECTED_CLASSIFICATION_RESULT)
# Run the absl test runner only when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

View File

@ -14,7 +14,7 @@
"""MediaPipe image classifier task."""
import dataclasses
from typing import Callable, List, Mapping, Optional
from typing import Callable, Mapping, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
@ -40,6 +40,7 @@ _TaskRunner = task_runner_module.TaskRunner
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
@ -66,8 +67,8 @@ class ImageClassifierOptions:
running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[
Callable[[classifications_module.ClassificationResult],
None]] = None
Callable[[classifications_module.ClassificationResult, image_module.Image,
int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
@ -134,7 +135,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
options.result_callback(classification_result)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image, timestamp)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -143,7 +146,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
':'.join([_CLASSIFICATION_RESULT_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
_CLASSIFICATION_RESULT_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
return cls(
@ -175,6 +179,40 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
return classifications_module.ClassificationResult([
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_for_video(
    self, image: image_module.Image,
    timestamp_ms: int
) -> classifications_module.ClassificationResult:
  """Performs image classification on the provided video frame.

  Only use this method when the ImageClassifier is created with the video
  running mode. It's required to provide the video frame's timestamp (in
  milliseconds) along with the video frame. The input timestamps should be
  monotonically increasing for adjacent calls of this method.

  Args:
    image: MediaPipe Image.
    timestamp_ms: The timestamp of the input video frame in milliseconds.

  Returns:
    A classification result object that contains a list of classifications.

  Raises:
    ValueError: If any of the input arguments is invalid.
    RuntimeError: If image classification failed to run.
  """
  # MediaPipe packet timestamps are expressed in microseconds, while this API
  # takes milliseconds; convert before stamping the packet so the graph sees
  # the real spacing between frames instead of one 1000x too small.
  micro_seconds_per_millisecond = 1000
  timestamp_us = timestamp_ms * micro_seconds_per_millisecond
  output_packets = self._process_video_data({
      _IMAGE_IN_STREAM_NAME:
          packet_creator.create_image(image).at(timestamp_us)
  })
  classification_result_proto = packet_getter.get_proto(
      output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
  # Wrap each classification proto in its Python data class.
  return classifications_module.ClassificationResult([
      classifications_module.Classifications.create_from_pb2(classification)
      for classification in classification_result_proto.classifications
  ])