Added the ClassifyForVideo API

This commit is contained in:
kinaryml 2022-10-05 05:24:52 -07:00
parent ef31e3a482
commit e250c903f5
2 changed files with 58 additions and 7 deletions

View File

@ -15,6 +15,7 @@
import enum
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
@ -121,6 +122,18 @@ class ImageClassifierTest(parameterized.TestCase):
# a context.
classifier.close()
def test_classify_for_video(self):
classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
self.test_image, timestamp)
self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT)
if __name__ == '__main__':
  absltest.main()

View File

@ -14,7 +14,7 @@
"""MediaPipe image classifier task."""
import dataclasses
from typing import Callable, Mapping, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
@ -40,6 +40,7 @@ _TaskRunner = task_runner_module.TaskRunner
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
@ -66,8 +67,8 @@ class ImageClassifierOptions:
running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[
    Callable[[classifications_module.ClassificationResult, image_module.Image,
              int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
@ -134,7 +135,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image, timestamp)
task_info = _TaskInfo(
    task_graph=_TASK_GRAPH_NAME,
@ -143,7 +146,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
    ':'.join([_CLASSIFICATION_RESULT_TAG,
              _CLASSIFICATION_RESULT_OUT_STREAM_NAME]),
    ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
return cls(
@ -175,6 +179,40 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
return classifications_module.ClassificationResult([
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_for_video(
self, image: image_module.Image,
timestamp_ms: int
) -> classifications_module.ClassificationResult:
"""Performs image classification on the provided video frames.
Only use this method when the ImageClassifier is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
Returns:
A classification result object that contains a list of classifications.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(timestamp_ms)
})
classification_result_proto = packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
return classifications_module.ClassificationResult([
classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
]) ])