From 022838a7f378d199876a9ab1c1c6b9ace03c8b29 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 9 Mar 2023 01:36:39 -0800 Subject: [PATCH 01/10] Added Face Detector implementation and tests --- mediapipe/python/BUILD | 1 + .../tasks/python/components/containers/BUILD | 10 + .../components/containers/detections.py | 40 +- .../python/components/containers/keypoint.py | 78 ++++ mediapipe/tasks/python/test/vision/BUILD | 22 + .../python/test/vision/face_detector_test.py | 407 ++++++++++++++++++ mediapipe/tasks/python/vision/BUILD | 20 + .../tasks/python/vision/face_detector.py | 308 +++++++++++++ 8 files changed, 882 insertions(+), 4 deletions(-) create mode 100644 mediapipe/tasks/python/components/containers/keypoint.py create mode 100644 mediapipe/tasks/python/test/vision/face_detector_test.py create mode 100644 mediapipe/tasks/python/vision/face_detector.py diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index f56e5b3d4..141b59d71 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -94,6 +94,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", ] + select({ # TODO: Build text_classifier_graph and text_embedder_graph on Windows. "//mediapipe:windows": [], diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 7108617ff..b84ab744d 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -73,12 +73,22 @@ py_library( ], ) +py_library( + name = "keypoint", + srcs = ["keypoint.py"], + deps = [ + "//mediapipe/framework/formats:location_data_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + py_library( name = "detections", srcs = ["detections.py"], deps = [ ":bounding_box", ":category", + ":keypoint", "//mediapipe/framework/formats:detection_py_pb2", "//mediapipe/framework/formats:location_data_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/components/containers/detections.py b/mediapipe/tasks/python/components/containers/detections.py index b4d550633..94fe16096 100644 --- a/mediapipe/tasks/python/components/containers/detections.py +++ b/mediapipe/tasks/python/components/containers/detections.py @@ -20,6 +20,7 @@ from mediapipe.framework.formats import detection_pb2 from mediapipe.framework.formats import location_data_pb2 from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import keypoint as keypoint_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _DetectionListProto = detection_pb2.DetectionList @@ -34,10 +35,12 @@ class Detection: Attributes: bounding_box: A BoundingBox object. categories: A list of Category objects. + keypoints: A list of NormalizedKeypoint objects. """ - bounding_box: bounding_box_module.BoundingBox - categories: List[category_module.Category] + bounding_box: bounding_box_module.BoundingBox = None + categories: List[category_module.Category] = None + keypoints: List[keypoint_module.NormalizedKeypoint] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _DetectionProto: @@ -46,6 +49,8 @@ class Detection: label_ids = [] scores = [] display_names = [] + relative_keypoints = [] + for category in self.categories: scores.append(category.score) if category.index: @@ -54,6 +59,20 @@ class Detection: labels.append(category.category_name) if category.display_name: display_names.append(category.display_name) + + if self.keypoints: + for keypoint in self.keypoints: + relative_keypoint_proto = _LocationDataProto.RelativeKeypoint() + if keypoint.x: + relative_keypoint_proto.x = keypoint.x + if keypoint.y: + relative_keypoint_proto.y = keypoint.y + if keypoint.label: + relative_keypoint_proto.keypoint_label = keypoint.label + if keypoint.score: + relative_keypoint_proto.score = keypoint.score + relative_keypoints.append(relative_keypoint_proto) + return _DetectionProto( label=labels, label_id=label_ids, @@ -61,13 +80,16 @@ class Detection: display_name=display_names, location_data=_LocationDataProto( format=_LocationDataProto.Format.BOUNDING_BOX, - bounding_box=self.bounding_box.to_pb2())) + bounding_box=self.bounding_box.to_pb2(), + relative_keypoints=relative_keypoints)) @classmethod @doc_controls.do_not_generate_docs def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection': """Creates a `Detection` object from the given protobuf object.""" categories = [] + keypoints = [] + for idx, score in enumerate(pb2_obj.score): categories.append( category_module.Category( @@ -79,10 +101,20 @@ class Detection: display_name=pb2_obj.display_name[idx] if idx < len(pb2_obj.display_name) else None)) + if pb2_obj.location_data.relative_keypoints: + for idx in range(len(pb2_obj.location_data.relative_keypoints)): + keypoints.append( + keypoint_module.NormalizedKeypoint( + x=pb2_obj.location_data.relative_keypoints[idx].x, + y=pb2_obj.location_data.relative_keypoints[idx].y, + label=pb2_obj.location_data.relative_keypoints[idx].keypoint_label, + score=pb2_obj.location_data.relative_keypoints[idx].score)) + return Detection( bounding_box=bounding_box_module.BoundingBox.create_from_pb2( pb2_obj.location_data.bounding_box), - categories=categories) + categories=categories, + keypoints=keypoints) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/keypoint.py b/mediapipe/tasks/python/components/containers/keypoint.py new file mode 100644 index 000000000..ef70c00b9 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/keypoint.py @@ -0,0 +1,78 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Keypoint data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.framework.formats import location_data_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_RelativeKeypointProto = location_data_pb2.LocationData.RelativeKeypoint + + +@dataclasses.dataclass +class NormalizedKeypoint: + """A normalized keypoint. + + Normalized keypoint represents a point in 2D space with x, y coordinates. + x and y are normalized to [0.0, 1.0] by the image width and height + respectively. + + Attributes: + x: The x coordinates of the normalized keypoint. + y: The y coordinates of the normalized keypoint. + label: The optional label of the keypoint. + score: The score of the keypoint. + """ + + x: Optional[float] = None + y: Optional[float] = None + label: Optional[str] = None + score: Optional[str] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _RelativeKeypointProto: + """Generates a RelativeKeypoint protobuf object.""" + return _RelativeKeypointProto( + x=self.x, + y=self.y, + keypoint_label=self.label, + score=self.score + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, + pb2_obj: _RelativeKeypointProto) -> 'NormalizedKeypoint': + """Creates a `NormalizedKeypoint` object from the given protobuf object.""" + return NormalizedKeypoint( + x=pb2_obj.x, + y=pb2_obj.y, + label=pb2_obj.keypoint_label, + score=pb2_obj.score) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedKeypoint): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 48ecc30b3..813f76bdb 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -114,3 +114,25 @@ py_test( "@com_google_protobuf//:protobuf_python", ], ) + +py_test( + name = "face_detector_test", + srcs = ["face_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + "//mediapipe/tasks/testdata/vision:test_protos", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:bounding_box", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_detector", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@com_google_protobuf//:protobuf_python", + ], +) diff --git a/mediapipe/tasks/python/test/vision/face_detector_test.py b/mediapipe/tasks/python/test/vision/face_detector_test.py new file mode 100644 index 000000000..90a52d110 --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_detector_test.py @@ -0,0 +1,407 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for face detector.""" + +import enum +import os +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from google.protobuf import text_format + +from mediapipe.framework.formats import detection_pb2 +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import detections as detections_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_detector +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module + + +FaceDetectorResult = detections_module.DetectionResult +_BaseOptions = base_options_module.BaseOptions +_Category = category_module.Category +_BoundingBox = bounding_box_module.BoundingBox +_Detection = detections_module.Detection +_Image = image_module.Image +_FaceDetector = face_detector.FaceDetector +_FaceDetectorOptions = face_detector.FaceDetectorOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_SHORT_RANGE_BLAZE_FACE_MODEL = 'face_detection_short_range.tflite' +_PORTRAIT_IMAGE = 'portrait.jpg' +_PORTRAIT_EXPECTED_DETECTION = 'portrait_expected_detection.pbtxt' +_PORTRAIT_ROTATED_IMAGE = 'portrait_rotated.jpg' +_PORTRAIT_ROTATED_EXPECTED_DETECTION = 'portrait_rotated_expected_detection.pbtxt' +_CAT_IMAGE = 'cat.jpg' +_KEYPOINT_ERROR_THRESHOLD = 1e-2 +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' + + +def _get_expected_face_detector_result(file_name: str) -> FaceDetectorResult: + face_detection_result_file_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, file_name)) + with open(face_detection_result_file_path, "rb") as f: + face_detection_proto = detection_pb2.Detection() + text_format.Parse(f.read(), face_detection_proto) + face_detection = detections_module.Detection.create_from_pb2(face_detection_proto) + return FaceDetectorResult(detections=[face_detection]) + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceDetectorTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _PORTRAIT_IMAGE))) + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _SHORT_RANGE_BLAZE_FACE_MODEL)) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _FaceDetector.create_from_model_path(self.model_path) as detector: + self.assertIsInstance(detector, _FaceDetector) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceDetectorOptions(base_options=base_options) + with _FaceDetector.create_from_options(options) as detector: + self.assertIsInstance(detector, _FaceDetector) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _FaceDetectorOptions(base_options=base_options) + _FaceDetector.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceDetectorOptions(base_options=base_options) + detector = _FaceDetector.create_from_options(options) + self.assertIsInstance(detector, _FaceDetector) + + def _expect_keypoints_correct(self, actual_keypoints, expected_keypoints): + self.assertLen(actual_keypoints, len(expected_keypoints)) + for i in range(len(actual_keypoints)): + self.assertAlmostEqual( + actual_keypoints[i].x, expected_keypoints[i].x, + delta=_KEYPOINT_ERROR_THRESHOLD) + self.assertAlmostEqual( + actual_keypoints[i].y, expected_keypoints[i].y, + delta=_KEYPOINT_ERROR_THRESHOLD) + + def _expect_face_detector_results_correct(self, actual_results, expected_results): + self.assertLen(actual_results.detections, len(expected_results.detections)) + for i in range(len(actual_results.detections)): + actual_bbox = actual_results.detections[i].bounding_box + expected_bbox = expected_results.detections[i].bounding_box + self.assertEqual(actual_bbox, expected_bbox) + self.assertNotEmpty(actual_results.detections[i].keypoints) + self._expect_keypoints_correct(actual_results.detections[i].keypoints, + expected_results.detections[i].keypoints) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION), + (ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION)) + def test_detect(self, model_file_type, expected_detection_result_file): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceDetectorOptions(base_options=base_options) + detector = _FaceDetector.create_from_options(options) + + # Performs face detection on the input. + detection_result = detector.detect(self.test_image) + # Comparing results. + expected_detection_result = _get_expected_face_detector_result( + expected_detection_result_file) + self._expect_face_detector_results_correct(detection_result, + expected_detection_result) + # Closes the detector explicitly when the detector is not used in + # a context. + detector.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION), + (ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION)) + def test_detect_in_context(self, model_file_type, expected_detection_result_file): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceDetectorOptions(base_options=base_options) + + with _FaceDetector.create_from_options(options) as detector: + # Performs face detection on the input. + detection_result = detector.detect(self.test_image) + # Comparing results. + expected_detection_result = _get_expected_face_detector_result( + expected_detection_result_file) + self._expect_face_detector_results_correct(detection_result, + expected_detection_result) + + def test_detect_succeeds_with_rotated_image(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceDetectorOptions(base_options=base_options) + with _FaceDetector.create_from_options(options) as detector: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _PORTRAIT_ROTATED_IMAGE))) + # Rotated input image. + image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) + # Performs face detection on the input. + detection_result = detector.detect(test_image, image_processing_options) + # Comparing results. + expected_detection_result = _get_expected_face_detector_result( + _PORTRAIT_ROTATED_EXPECTED_DETECTION) + self._expect_face_detector_results_correct(detection_result, + expected_detection_result) + + def test_empty_detection_outputs(self): + # Load a test image with no faces. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path)) + with _FaceDetector.create_from_options(options) as detector: + # Performs object detection on the input. + detection_result = detector.detect(test_image) + self.assertEmpty(detection_result.detections) + + def test_missing_result_callback(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _FaceDetector.create_from_options(options) as unused_detector: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _FaceDetector.create_from_options(options) as unused_detector: + pass + + def test_calling_detect_for_video_in_image_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + detector.detect_for_video(self.test_image, 0) + + def test_calling_detect_async_in_image_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + detector.detect_async(self.test_image, 0) + + def test_calling_detect_in_video_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + detector.detect(self.test_image) + + def test_calling_detect_async_in_video_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + detector.detect_async(self.test_image, 0) + + def test_detect_for_video_with_out_of_order_timestamp(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _FaceDetector.create_from_options(options) as detector: + unused_result = detector.detect_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + detector.detect_for_video(self.test_image, 0) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _PORTRAIT_IMAGE, 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION)), + (ModelFileType.FILE_CONTENT, _PORTRAIT_IMAGE, 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION)), + (ModelFileType.FILE_NAME, _PORTRAIT_ROTATED_IMAGE, -90, + _get_expected_face_detector_result(_PORTRAIT_ROTATED_EXPECTED_DETECTION)), + (ModelFileType.FILE_CONTENT, _PORTRAIT_ROTATED_IMAGE, -90, + _get_expected_face_detector_result(_PORTRAIT_ROTATED_EXPECTED_DETECTION)), + (ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])), + (ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([]))) + def test_detect_for_video(self, model_file_type, test_image_file_name, + rotation_degrees, expected_detection_result): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceDetectorOptions(base_options=base_options, + running_mode=_RUNNING_MODE.VIDEO) + + with _FaceDetector.create_from_options(options) as detector: + for timestamp in range(0, 300, 30): + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, test_image_file_name))) + # Set the image processing options. + image_processing_options = _ImageProcessingOptions( + rotation_degrees=rotation_degrees) + # Performs face detection on the input. + detection_result = detector.detect_for_video(test_image, timestamp, + image_processing_options) + # Comparing results. + self._expect_face_detector_results_correct(detection_result, + expected_detection_result) + + def test_calling_detect_in_live_stream_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + detector.detect(self.test_image) + + def test_calling_detect_for_video_in_live_stream_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + detector.detect_for_video(self.test_image, 0) + + def test_detect_async_calls_with_illegal_timestamp(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _FaceDetector.create_from_options(options) as detector: + detector.detect_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + detector.detect_async(self.test_image, 0) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _PORTRAIT_IMAGE, 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION)), + (ModelFileType.FILE_CONTENT, _PORTRAIT_IMAGE, 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION)), + (ModelFileType.FILE_NAME, _PORTRAIT_ROTATED_IMAGE, -90, + _get_expected_face_detector_result(_PORTRAIT_ROTATED_EXPECTED_DETECTION)), + (ModelFileType.FILE_CONTENT, _PORTRAIT_ROTATED_IMAGE, -90, + _get_expected_face_detector_result(_PORTRAIT_ROTATED_EXPECTED_DETECTION)), + (ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])), + (ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([]))) + def test_detect_async_calls(self, model_file_type, test_image_file_name, + rotation_degrees, expected_detection_result): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + observed_timestamp_ms = -1 + + def check_result(result: FaceDetectorResult, output_image: _Image, + timestamp_ms: int): + self._expect_face_detector_results_correct(result, + expected_detection_result) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + options = _FaceDetectorOptions(base_options=base_options, + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=check_result) + + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, test_image_file_name))) + + with _FaceDetector.create_from_options(options) as detector: + for timestamp in range(0, 300, 30): + # Set the image processing options. + image_processing_options = _ImageProcessingOptions( + rotation_degrees=rotation_degrees) + detector.detect_async(test_image, timestamp, image_processing_options) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index eda8e290d..891286641 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -152,3 +152,23 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "face_detector", + srcs = [ + "face_detector.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/face_detector.py b/mediapipe/tasks/python/vision/face_detector.py new file mode 100644 index 000000000..91baecff4 --- /dev/null +++ b/mediapipe/tasks/python/vision/face_detector.py @@ -0,0 +1,308 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe face detector task.""" + +import dataclasses +from typing import Callable, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.tasks.cc.vision.face_detector.proto import face_detector_graph_options_pb2 +from mediapipe.tasks.python.components.containers import detections as detections_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +FaceDetectorResult = detections_module.DetectionResult +_BaseOptions = base_options_module.BaseOptions +_FaceDetectorGraphOptionsProto = face_detector_graph_options_pb2.FaceDetectorGraphOptions +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_DETECTIONS_OUT_STREAM_NAME = 'detections' +_DETECTIONS_TAG = 'DETECTIONS' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class FaceDetectorOptions: + """Options for the face detector task. + + Attributes: + base_options: Base options for the face detector task. + running_mode: The running mode of the task. Default to the image mode. + Face detector task has three running modes: + 1) The image mode for detecting faces on single image inputs. + 2) The video mode for detecting faces on the decoded frames of a video. + 3) The live stream mode for detecting faces on a live stream of input + data, such as from camera. + min_detection_confidence: The minimum confidence score for the face + detection to be considered successful. + min_suppression_threshold: The minimum non-maximum-suppression threshold + for face detection to be considered overlapped. + num_faces: Maximum number of faces to detect in the image. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + min_detection_confidence: Optional[float] = None + min_suppression_threshold: Optional[float] = None + num_faces: Optional[int] = None + result_callback: Optional[ + Callable[[detections_module.DetectionResult, image_module.Image, int], + None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceDetectorGraphOptionsProto: + """Generates an FaceDetectorOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + return _FaceDetectorGraphOptionsProto( + base_options=base_options_proto, + min_detection_confidence=self.min_detection_confidence, + min_suppression_threshold=self.min_suppression_threshold, + num_faces=self.num_faces + ) + + +class FaceDetector(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face detection on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceDetector': + """Creates an `FaceDetector` object from a TensorFlow Lite model and the default `FaceDetectorOptions`. + + Note that the created `FaceDetector` instance is in image mode, for + detecting faces on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `FaceDetector` object that's created from the model file and the default + `FaceDetectorOptions`. + + Raises: + ValueError: If failed to create `FaceDetector` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceDetectorOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: FaceDetectorOptions) -> 'FaceDetector': + """Creates the `FaceDetector` object from face detector options. + + Args: + options: Options for the face detector task. + + Returns: + `FaceDetector` object that's created from `options`. + + Raises: + ValueError: If failed to create `FaceDetector` object from + `FaceDetectorOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME] + options.result_callback( + FaceDetectorResult([]), image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + return + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME]) + detection_result = detections_module.DetectionResult([ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ]) + + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp + options.result_callback(detection_result, image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=[ + ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ], + task_options=options) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode == + _RunningMode.LIVE_STREAM), options.running_mode, + packets_callback if options.result_callback else None) + + def detect( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> FaceDetectorResult: + """Performs face detection on the provided MediaPipe Image. + + Only use this method when the FaceDetector is created with the image + running mode. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + A face detection result object that contains a list of face detections, + each detection has a bounding box that is expressed in the unrotated input + frame of reference coordinates system, i.e. in `[0,image_width) x [0, + image_height)`, which are the dimensions of the underlying image data. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()) + }) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + return FaceDetectorResult([]) + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME]) + return detections_module.DetectionResult([ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ]) + + def detect_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> detections_module.DetectionResult: + """Performs face detection on the provided video frames. + + Only use this method when the FaceDetector is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. + + Returns: + A face detection result object that contains a list of face detections, + each detection has a bounding box that is expressed in the unrotated input + frame of reference coordinates system, i.e. in `[0,image_width) x [0, + image_height)`, which are the dimensions of the underlying image data. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + return FaceDetectorResult([]) + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME]) + return detections_module.DetectionResult([ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ]) + + def detect_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> None: + """Sends live image data (an Image with a unique timestamp) to perform face detection. + + Only use this method when the FaceDetector is created with the live stream + running mode. The input timestamps should be monotonically increasing for + adjacent calls of this method. This method will return immediately after the + input image is accepted. The results will be available via the + `result_callback` provided in the `FaceDetectorOptions`. The + `detect_async` method is designed to process live stream data such as camera + input. To lower the overall latency, face detector may drop the input + images if needed. In other words, it's not guaranteed to have output per + input image. + + The `result_callback` provides: + - A face detection result object that contains a list of face detections, + each detection has a bounding box that is expressed in the unrotated + input frame of reference coordinates system, + i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions + of the underlying image data. + - The input image that the face detector runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. + + Raises: + ValueError: If the current input timestamp is smaller than what the face + detector has already processed. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) From 24114ec2fec7a216903cb3b8fb08b657569f8648 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 9 Mar 2023 01:41:42 -0800 Subject: [PATCH 02/10] Updated comment in test --- mediapipe/tasks/python/test/vision/face_detector_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/vision/face_detector_test.py b/mediapipe/tasks/python/test/vision/face_detector_test.py index 90a52d110..f78c9c94e 100644 --- a/mediapipe/tasks/python/test/vision/face_detector_test.py +++ b/mediapipe/tasks/python/test/vision/face_detector_test.py @@ -209,7 +209,7 @@ class FaceDetectorTest(parameterized.TestCase): options = _FaceDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path)) with _FaceDetector.create_from_options(options) as detector: - # Performs object detection on the input. + # Performs face detection on the input. detection_result = detector.detect(test_image) self.assertEmpty(detection_result.detections) From f48909cab63243a477207e65f0ad08c079613baa Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 9 Mar 2023 02:13:34 -0800 Subject: [PATCH 03/10] Fixed score's data type --- mediapipe/tasks/python/components/containers/keypoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/components/containers/keypoint.py b/mediapipe/tasks/python/components/containers/keypoint.py index ef70c00b9..ef91d0950 100644 --- a/mediapipe/tasks/python/components/containers/keypoint.py +++ b/mediapipe/tasks/python/components/containers/keypoint.py @@ -40,7 +40,7 @@ class NormalizedKeypoint: x: Optional[float] = None y: Optional[float] = None label: Optional[str] = None - score: Optional[str] = None + score: Optional[float] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _RelativeKeypointProto: From ce3cd94f457970502adb855fc723d2d13ae47980 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Mar 2023 10:54:21 -0700 Subject: [PATCH 04/10] Internal change PiperOrigin-RevId: 516871638 --- LICENSE | 17 ++ .../language_detector/custom_ops/utils/BUILD | 42 ++++ .../custom_ops/utils/ngram_hash_ops_utils.cc | 96 ++++++++ .../custom_ops/utils/ngram_hash_ops_utils.h | 56 +++++ .../utils/ngram_hash_ops_utils_test.cc | 135 ++++++++++ .../custom_ops/utils/utf/BUILD | 27 ++ .../custom_ops/utils/utf/rune.c | 233 ++++++++++++++++++ .../custom_ops/utils/utf/runetype.c | 54 ++++ .../custom_ops/utils/utf/runetypebody.h | 212 ++++++++++++++++ .../custom_ops/utils/utf/utf.h | 98 ++++++++ 10 files changed, 970 insertions(+) create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h diff --git a/LICENSE b/LICENSE index 261eeb9e9..0e03e3911 100644 --- a/LICENSE +++ b/LICENSE @@ -199,3 +199,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +=========================================================================== +For files under tasks/cc/text/language_detector/custom_ops/utils/utf/ +=========================================================================== +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + */ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD new file mode 100644 index 000000000..9f2fe298a --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD @@ -0,0 +1,42 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "ngram_hash_ops_utils", + srcs = [ + "ngram_hash_ops_utils.cc", + ], + hdrs = [ + "ngram_hash_ops_utils.h", + ], + deps = [ + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf", + ], +) + +cc_test( + name = "ngram_hash_ops_utils_test", + size = "small", + srcs = [ + "ngram_hash_ops_utils_test.cc", + ], + deps = [ + ":ngram_hash_ops_utils", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc new file mode 100644 index 000000000..f1ad71fc1 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc @@ -0,0 +1,96 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" + +#include +#include +#include + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops { + +TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens, + bool exclude_nonalphaspace_tokens) { + const std::string kPrefix = "^"; + const std::string kSuffix = "$"; + const std::string kReplacementToken = " "; + + TokenizedOutput output; + + size_t token_start = 0; + output.str.reserve(len + 2); + output.tokens.reserve(len + 2); + + output.str.append(kPrefix); + output.tokens.push_back(std::make_pair(token_start, kPrefix.size())); + token_start += kPrefix.size(); + + Rune token; + for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) { + // Use the standard UTF-8 library to find the next token. + size_t bytes_read = utf_charntorune(&token, input_str + i, len - i); + + // Stop processing, if we can't read any more tokens, or we have reached + // maximum allowed tokens, allocating one token for the suffix. + if (bytes_read == 0) { + break; + } + + // If `exclude_nonalphaspace_tokens` is set to true, and the token is not + // alphanumeric, replace it with a replacement token. + if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) { + output.str.append(kReplacementToken); + output.tokens.push_back( + std::make_pair(token_start, kReplacementToken.size())); + token_start += kReplacementToken.size(); + i += bytes_read; + continue; + } + + // Append the token in the output string, and note its position and the + // number of bytes that token consumed. + output.str.append(input_str + i, bytes_read); + output.tokens.push_back(std::make_pair(token_start, bytes_read)); + token_start += bytes_read; + i += bytes_read; + } + output.str.append(kSuffix); + output.tokens.push_back(std::make_pair(token_start, kSuffix.size())); + token_start += kSuffix.size(); + + return output; +} + +void LowercaseUnicodeStr(const char* input_str, int len, + std::string* output_str) { + for (int i = 0; i < len;) { + Rune token; + + // Tokenize the given string, and get the appropriate lowercase token. + size_t bytes_read = utf_charntorune(&token, input_str + i, len - i); + token = utf_isalpharune(token) ? utf_tolowerrune(token) : token; + + // Write back the token to the output string. + char token_buf[UTFmax]; + size_t bytes_to_write = utf_runetochar(token_buf, &token); + output_str->append(token_buf, bytes_to_write); + + i += bytes_read; + } +} + +} // namespace mediapipe::tasks::text::language_detector::custom_ops diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h new file mode 100644 index 000000000..9a80554c8 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h @@ -0,0 +1,56 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_ + +#include +#include +#include + +namespace mediapipe::tasks::text::language_detector::custom_ops { + +struct TokenizedOutput { + // The processed string (with necessary prefix, suffix, skipped tokens, etc.). + std::string str; + + // This vector contains pairs, where each pair has two members. The first + // denoting the starting index of the token in the `str` string, and the + // second denoting the length of that token in bytes. + std::vector> tokens; +}; + +// Tokenizes the given input string on Unicode token boundaries, with a maximum +// of `max_tokens` tokens. +// +// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores +// non-alphanumeric tokens, and replaces them with a replacement token (" "). +// +// The method returns the output in the `TokenizedOutput` struct, which stores +// both, the processed input string, and the indices and sizes of each token +// within that string. +TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens, + bool exclude_nonalphaspace_tokens); + +// Converts the given unicode string (`input_str`) with the specified length +// (`len`) to a lowercase string. +// +// The method populates the lowercased string in `output_str`. +void LowercaseUnicodeStr(const char* input_str, int len, + std::string* output_str); + +} // namespace mediapipe::tasks::text::language_detector::custom_ops + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc new file mode 100644 index 000000000..d22af1c95 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" + +#include + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops { + +namespace { + +using ::testing::Values; + +std::string ReconstructStringFromTokens(TokenizedOutput output) { + std::string reconstructed_str; + for (int i = 0; i < output.tokens.size(); i++) { + reconstructed_str.append( + output.str.c_str() + output.tokens[i].first, + output.str.c_str() + output.tokens[i].first + output.tokens[i].second); + } + return reconstructed_str; +} + +struct TokenizeTestParams { + std::string input_str; + size_t max_tokens; + bool exclude_nonalphaspace_tokens; + std::string expected_output_str; +}; + +class TokenizeParameterizedTest + : public ::testing::Test, + public testing::WithParamInterface {}; + +TEST_P(TokenizeParameterizedTest, Tokenize) { + // Checks that the Tokenize method returns the expected value. + const TokenizeTestParams params = TokenizeParameterizedTest::GetParam(); + const TokenizedOutput output = Tokenize( + /*input_str=*/params.input_str.c_str(), + /*len=*/params.input_str.size(), + /*max_tokens=*/params.max_tokens, + /*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens); + + // The output string should have the necessary prefixes, and the "!" token + // should have been replaced with a " ". + EXPECT_EQ(output.str, params.expected_output_str); + EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str); +} + +INSTANTIATE_TEST_SUITE_P( + TokenizeParameterizedTests, TokenizeParameterizedTest, + Values( + // Test including non-alphanumeric characters. + TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100, + /*exclude_alphanonspace=*/false, + /*expected_output_str=*/"^hi!$"}), + // Test not including non-alphanumeric characters. + TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100, + /*exclude_alphanonspace=*/true, + /*expected_output_str=*/"^hi $"}), + // Test with a maximum of 3 tokens. + TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3, + /*exclude_alphanonspace=*/true, + /*expected_output_str=*/"^h$"}), + // Test with non-latin characters. + TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100, + /*exclude_alphanonspace=*/true, + /*expected_output_str=*/"^ありがと$"}))); + +TEST(LowercaseUnicodeTest, TestLowercaseUnicode) { + { + // Check that the method is a no-op when the string is lowercase. + std::string input_str = "hello"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "hello"); + } + { + // Check that the method has uppercase characters. + std::string input_str = "hElLo"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "hello"); + } + { + // Check that the method works with non-latin scripts. + // Cyrillic has the concept of cases, so it should change the input. + std::string input_str = "БЙп"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "бйп"); + } + { + // Check that the method works with non-latin scripts. + // Japanese doesn't have the concept of cases, so it should not change. + std::string input_str = "ありがと"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "ありがと"); + } +} + +} // namespace +} // namespace mediapipe::tasks::text::language_detector::custom_ops diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD new file mode 100644 index 000000000..a71845305 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "utf", + srcs = [ + "rune.c", + "runetype.c", + "runetypebody.h", + ], + hdrs = ["utf.h"], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c new file mode 100644 index 000000000..b74450f44 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c @@ -0,0 +1,233 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Forked from a library written by Rob Pike and Ken Thompson. Original +// copyright message below. +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + */ +#include +#include +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h" + +enum +{ + Bit1 = 7, + Bitx = 6, + Bit2 = 5, + Bit3 = 4, + Bit4 = 3, + Bit5 = 2, + + T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */ + Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */ + T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */ + T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */ + T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */ + T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */ + + Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */ + Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */ + Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */ + Rune4 = (1<<(Bit4+3*Bitx))-1, + /* 0001 1111 1111 1111 1111 1111 */ + + Maskx = (1< T1 + */ + c = *(uchar*)str; + if(c < Tx) { + *rune = c; + return 1; + } + + // If we can't read more than one character we must stop + if(length <= 1) { + goto badlen; + } + + /* + * two character sequence (11-bit value) + * 0080-07FF => T2 Tx + */ + c1 = *(uchar*)(str+1) ^ Tx; + if(c1 & Testx) + goto bad; + if(c < T3) { + if(c < T2) + goto bad; + l = ((c << Bitx) | c1) & Rune2; + if(l <= Rune1) + goto bad; + *rune = l; + return 2; + } + + // If we can't read more than two characters we must stop + if(length <= 2) { + goto badlen; + } + + /* + * three character sequence (16-bit value) + * 0800-FFFF => T3 Tx Tx + */ + c2 = *(uchar*)(str+2) ^ Tx; + if(c2 & Testx) + goto bad; + if(c < T4) { + l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3; + if(l <= Rune2) + goto bad; + *rune = l; + return 3; + } + + if (length <= 3) + goto badlen; + + /* + * four character sequence (21-bit value) + * 10000-1FFFFF => T4 Tx Tx Tx + */ + c3 = *(uchar*)(str+3) ^ Tx; + if (c3 & Testx) + goto bad; + if (c < T5) { + l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4; + if (l <= Rune3) + goto bad; + if (l > Runemax) + goto bad; + *rune = l; + return 4; + } + + // Support for 5-byte or longer UTF-8 would go here, but + // since we don't have that, we'll just fall through to bad. + + /* + * bad decoding + */ +bad: + *rune = Bad; + return 1; +badlen: + *rune = Bad; + return 0; + +} + +int +utf_runetochar(char *str, const Rune *rune) +{ + /* Runes are signed, so convert to unsigned for range check. */ + unsigned long c; + + /* + * one character sequence + * 00000-0007F => 00-7F + */ + c = *rune; + if(c <= Rune1) { + str[0] = c; + return 1; + } + + /* + * two character sequence + * 0080-07FF => T2 Tx + */ + if(c <= Rune2) { + str[0] = T2 | (c >> 1*Bitx); + str[1] = Tx | (c & Maskx); + return 2; + } + + /* + * If the Rune is out of range, convert it to the error rune. + * Do this test here because the error rune encodes to three bytes. + * Doing it earlier would duplicate work, since an out of range + * Rune wouldn't have fit in one or two bytes. + */ + if (c > Runemax) + c = Runeerror; + + /* + * three character sequence + * 0800-FFFF => T3 Tx Tx + */ + if (c <= Rune3) { + str[0] = T3 | (c >> 2*Bitx); + str[1] = Tx | ((c >> 1*Bitx) & Maskx); + str[2] = Tx | (c & Maskx); + return 3; + } + + /* + * four character sequence (21-bit value) + * 10000-1FFFFF => T4 Tx Tx Tx + */ + str[0] = T4 | (c >> 3*Bitx); + str[1] = Tx | ((c >> 2*Bitx) & Maskx); + str[2] = Tx | ((c >> 1*Bitx) & Maskx); + str[3] = Tx | (c & Maskx); + return 4; +} diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c new file mode 100644 index 000000000..1dd8abdbd --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c @@ -0,0 +1,54 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Forked from a library written by Rob Pike and Ken Thompson. Original +// copyright message below. +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + */ +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h" + +static +Rune* +rbsearch(Rune c, Rune *t, int n, int ne) +{ + Rune *p; + int m; + + while(n > 1) { + m = n >> 1; + p = t + m*ne; + if(c >= p[0]) { + t = p; + n = n-m; + } else + n = m; + } + if(n && c >= t[0]) + return t; + return 0; +} + +#define RUNETYPEBODY +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h" diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h new file mode 100644 index 000000000..66d1dfc19 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h @@ -0,0 +1,212 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef RUNETYPEBODY + +static Rune __isalphar[] = { + 0x0041, 0x005a, 0x0061, 0x007a, 0x00c0, 0x00d6, 0x00d8, 0x00f6, + 0x00f8, 0x02c1, 0x02c6, 0x02d1, 0x02e0, 0x02e4, 0x0370, 0x0374, + 0x0376, 0x0377, 0x037a, 0x037d, 0x0388, 0x038a, 0x038e, 0x03a1, + 0x03a3, 0x03f5, 0x03f7, 0x0481, 0x048a, 0x0527, 0x0531, 0x0556, + 0x0561, 0x0587, 0x05d0, 0x05ea, 0x05f0, 0x05f2, 0x0620, 0x064a, + 0x066e, 0x066f, 0x0671, 0x06d3, 0x06e5, 0x06e6, 0x06ee, 0x06ef, + 0x06fa, 0x06fc, 0x0712, 0x072f, 0x074d, 0x07a5, 0x07ca, 0x07ea, + 0x07f4, 0x07f5, 0x0800, 0x0815, 0x0840, 0x0858, 0x08a2, 0x08ac, + 0x0904, 0x0939, 0x0958, 0x0961, 0x0971, 0x0977, 0x0979, 0x097f, + 0x0985, 0x098c, 0x098f, 0x0990, 0x0993, 0x09a8, 0x09aa, 0x09b0, + 0x09b6, 0x09b9, 0x09dc, 0x09dd, 0x09df, 0x09e1, 0x09f0, 0x09f1, + 0x0a05, 0x0a0a, 0x0a0f, 0x0a10, 0x0a13, 0x0a28, 0x0a2a, 0x0a30, + 0x0a32, 0x0a33, 0x0a35, 0x0a36, 0x0a38, 0x0a39, 0x0a59, 0x0a5c, + 0x0a72, 0x0a74, 0x0a85, 0x0a8d, 0x0a8f, 0x0a91, 0x0a93, 0x0aa8, + 0x0aaa, 0x0ab0, 0x0ab2, 0x0ab3, 0x0ab5, 0x0ab9, 0x0ae0, 0x0ae1, + 0x0b05, 0x0b0c, 0x0b0f, 0x0b10, 0x0b13, 0x0b28, 0x0b2a, 0x0b30, + 0x0b32, 0x0b33, 0x0b35, 0x0b39, 0x0b5c, 0x0b5d, 0x0b5f, 0x0b61, + 0x0b85, 0x0b8a, 0x0b8e, 0x0b90, 0x0b92, 0x0b95, 0x0b99, 0x0b9a, + 0x0b9e, 0x0b9f, 0x0ba3, 0x0ba4, 0x0ba8, 0x0baa, 0x0bae, 0x0bb9, + 0x0c05, 0x0c0c, 0x0c0e, 0x0c10, 0x0c12, 0x0c28, 0x0c2a, 0x0c33, + 0x0c35, 0x0c39, 0x0c58, 0x0c59, 0x0c60, 0x0c61, 0x0c85, 0x0c8c, + 0x0c8e, 0x0c90, 0x0c92, 0x0ca8, 0x0caa, 0x0cb3, 0x0cb5, 0x0cb9, + 0x0ce0, 0x0ce1, 0x0cf1, 0x0cf2, 0x0d05, 0x0d0c, 0x0d0e, 0x0d10, + 0x0d12, 0x0d3a, 0x0d60, 0x0d61, 0x0d7a, 0x0d7f, 0x0d85, 0x0d96, + 0x0d9a, 0x0db1, 0x0db3, 0x0dbb, 0x0dc0, 0x0dc6, 0x0e01, 0x0e30, + 0x0e32, 0x0e33, 0x0e40, 0x0e46, 0x0e81, 0x0e82, 0x0e87, 0x0e88, + 0x0e94, 0x0e97, 0x0e99, 0x0e9f, 0x0ea1, 0x0ea3, 0x0eaa, 0x0eab, + 0x0ead, 0x0eb0, 0x0eb2, 0x0eb3, 0x0ec0, 0x0ec4, 0x0edc, 0x0edf, + 0x0f40, 0x0f47, 0x0f49, 0x0f6c, 0x0f88, 0x0f8c, 0x1000, 0x102a, + 0x1050, 0x1055, 0x105a, 0x105d, 0x1065, 0x1066, 0x106e, 0x1070, + 0x1075, 0x1081, 0x10a0, 0x10c5, 0x10d0, 0x10fa, 0x10fc, 0x1248, + 0x124a, 0x124d, 0x1250, 0x1256, 0x125a, 0x125d, 0x1260, 0x1288, + 0x128a, 0x128d, 0x1290, 0x12b0, 0x12b2, 0x12b5, 0x12b8, 0x12be, + 0x12c2, 0x12c5, 0x12c8, 0x12d6, 0x12d8, 0x1310, 0x1312, 0x1315, + 0x1318, 0x135a, 0x1380, 0x138f, 0x13a0, 0x13f4, 0x1401, 0x166c, + 0x166f, 0x167f, 0x1681, 0x169a, 0x16a0, 0x16ea, 0x1700, 0x170c, + 0x170e, 0x1711, 0x1720, 0x1731, 0x1740, 0x1751, 0x1760, 0x176c, + 0x176e, 0x1770, 0x1780, 0x17b3, 0x1820, 0x1877, 0x1880, 0x18a8, + 0x18b0, 0x18f5, 0x1900, 0x191c, 0x1950, 0x196d, 0x1970, 0x1974, + 0x1980, 0x19ab, 0x19c1, 0x19c7, 0x1a00, 0x1a16, 0x1a20, 0x1a54, + 0x1b05, 0x1b33, 0x1b45, 0x1b4b, 0x1b83, 0x1ba0, 0x1bae, 0x1baf, + 0x1bba, 0x1be5, 0x1c00, 0x1c23, 0x1c4d, 0x1c4f, 0x1c5a, 0x1c7d, + 0x1ce9, 0x1cec, 0x1cee, 0x1cf1, 0x1cf5, 0x1cf6, 0x1d00, 0x1dbf, + 0x1e00, 0x1f15, 0x1f18, 0x1f1d, 0x1f20, 0x1f45, 0x1f48, 0x1f4d, + 0x1f50, 0x1f57, 0x1f5f, 0x1f7d, 0x1f80, 0x1fb4, 0x1fb6, 0x1fbc, + 0x1fc2, 0x1fc4, 0x1fc6, 0x1fcc, 0x1fd0, 0x1fd3, 0x1fd6, 0x1fdb, + 0x1fe0, 0x1fec, 0x1ff2, 0x1ff4, 0x1ff6, 0x1ffc, 0x2090, 0x209c, + 0x210a, 0x2113, 0x2119, 0x211d, 0x212a, 0x212d, 0x212f, 0x2139, + 0x213c, 0x213f, 0x2145, 0x2149, 0x2183, 0x2184, 0x2c00, 0x2c2e, + 0x2c30, 0x2c5e, 0x2c60, 0x2ce4, 0x2ceb, 0x2cee, 0x2cf2, 0x2cf3, + 0x2d00, 0x2d25, 0x2d30, 0x2d67, 0x2d80, 0x2d96, 0x2da0, 0x2da6, + 0x2da8, 0x2dae, 0x2db0, 0x2db6, 0x2db8, 0x2dbe, 0x2dc0, 0x2dc6, + 0x2dc8, 0x2dce, 0x2dd0, 0x2dd6, 0x2dd8, 0x2dde, 0x3005, 0x3006, + 0x3031, 0x3035, 0x303b, 0x303c, 0x3041, 0x3096, 0x309d, 0x309f, + 0x30a1, 0x30fa, 0x30fc, 0x30ff, 0x3105, 0x312d, 0x3131, 0x318e, + 0x31a0, 0x31ba, 0x31f0, 0x31ff, 0x3400, 0x4db5, 0x4e00, 0x9fcc, + 0xa000, 0xa48c, 0xa4d0, 0xa4fd, 0xa500, 0xa60c, 0xa610, 0xa61f, + 0xa62a, 0xa62b, 0xa640, 0xa66e, 0xa67f, 0xa697, 0xa6a0, 0xa6e5, + 0xa717, 0xa71f, 0xa722, 0xa788, 0xa78b, 0xa78e, 0xa790, 0xa793, + 0xa7a0, 0xa7aa, 0xa7f8, 0xa801, 0xa803, 0xa805, 0xa807, 0xa80a, + 0xa80c, 0xa822, 0xa840, 0xa873, 0xa882, 0xa8b3, 0xa8f2, 0xa8f7, + 0xa90a, 0xa925, 0xa930, 0xa946, 0xa960, 0xa97c, 0xa984, 0xa9b2, + 0xaa00, 0xaa28, 0xaa40, 0xaa42, 0xaa44, 0xaa4b, 0xaa60, 0xaa76, + 0xaa80, 0xaaaf, 0xaab5, 0xaab6, 0xaab9, 0xaabd, 0xaadb, 0xaadd, + 0xaae0, 0xaaea, 0xaaf2, 0xaaf4, 0xab01, 0xab06, 0xab09, 0xab0e, + 0xab11, 0xab16, 0xab20, 0xab26, 0xab28, 0xab2e, 0xabc0, 0xabe2, + 0xac00, 0xd7a3, 0xd7b0, 0xd7c6, 0xd7cb, 0xd7fb, 0xf900, 0xfa6d, + 0xfa70, 0xfad9, 0xfb00, 0xfb06, 0xfb13, 0xfb17, 0xfb1f, 0xfb28, + 0xfb2a, 0xfb36, 0xfb38, 0xfb3c, 0xfb40, 0xfb41, 0xfb43, 0xfb44, + 0xfb46, 0xfbb1, 0xfbd3, 0xfd3d, 0xfd50, 0xfd8f, 0xfd92, 0xfdc7, + 0xfdf0, 0xfdfb, 0xfe70, 0xfe74, 0xfe76, 0xfefc, 0xff21, 0xff3a, + 0xff41, 0xff5a, 0xff66, 0xffbe, 0xffc2, 0xffc7, 0xffca, 0xffcf, + 0xffd2, 0xffd7, 0xffda, 0xffdc, 0x10000, 0x1000b, 0x1000d, 0x10026, + 0x10028, 0x1003a, 0x1003c, 0x1003d, 0x1003f, 0x1004d, 0x10050, 0x1005d, + 0x10080, 0x100fa, 0x10280, 0x1029c, 0x102a0, 0x102d0, 0x10300, 0x1031e, + 0x10330, 0x10340, 0x10342, 0x10349, 0x10380, 0x1039d, 0x103a0, 0x103c3, + 0x103c8, 0x103cf, 0x10400, 0x1049d, 0x10800, 0x10805, 0x1080a, 0x10835, + 0x10837, 0x10838, 0x1083f, 0x10855, 0x10900, 0x10915, 0x10920, 0x10939, + 0x10980, 0x109b7, 0x109be, 0x109bf, 0x10a10, 0x10a13, 0x10a15, 0x10a17, + 0x10a19, 0x10a33, 0x10a60, 0x10a7c, 0x10b00, 0x10b35, 0x10b40, 0x10b55, + 0x10b60, 0x10b72, 0x10c00, 0x10c48, 0x11003, 0x11037, 0x11083, 0x110af, + 0x110d0, 0x110e8, 0x11103, 0x11126, 0x11183, 0x111b2, 0x111c1, 0x111c4, + 0x11680, 0x116aa, 0x12000, 0x1236e, 0x13000, 0x1342e, 0x16800, 0x16a38, + 0x16f00, 0x16f44, 0x16f93, 0x16f9f, 0x1b000, 0x1b001, 0x1d400, 0x1d454, + 0x1d456, 0x1d49c, 0x1d49e, 0x1d49f, 0x1d4a5, 0x1d4a6, 0x1d4a9, 0x1d4ac, + 0x1d4ae, 0x1d4b9, 0x1d4bd, 0x1d4c3, 0x1d4c5, 0x1d505, 0x1d507, 0x1d50a, + 0x1d50d, 0x1d514, 0x1d516, 0x1d51c, 0x1d51e, 0x1d539, 0x1d53b, 0x1d53e, + 0x1d540, 0x1d544, 0x1d54a, 0x1d550, 0x1d552, 0x1d6a5, 0x1d6a8, 0x1d6c0, + 0x1d6c2, 0x1d6da, 0x1d6dc, 0x1d6fa, 0x1d6fc, 0x1d714, 0x1d716, 0x1d734, + 0x1d736, 0x1d74e, 0x1d750, 0x1d76e, 0x1d770, 0x1d788, 0x1d78a, 0x1d7a8, + 0x1d7aa, 0x1d7c2, 0x1d7c4, 0x1d7cb, 0x1ee00, 0x1ee03, 0x1ee05, 0x1ee1f, + 0x1ee21, 0x1ee22, 0x1ee29, 0x1ee32, 0x1ee34, 0x1ee37, 0x1ee4d, 0x1ee4f, + 0x1ee51, 0x1ee52, 0x1ee61, 0x1ee62, 0x1ee67, 0x1ee6a, 0x1ee6c, 0x1ee72, + 0x1ee74, 0x1ee77, 0x1ee79, 0x1ee7c, 0x1ee80, 0x1ee89, 0x1ee8b, 0x1ee9b, + 0x1eea1, 0x1eea3, 0x1eea5, 0x1eea9, 0x1eeab, 0x1eebb, 0x20000, 0x2a6d6, + 0x2a700, 0x2b734, 0x2b740, 0x2b81d, 0x2f800, 0x2fa1d, +}; + +static Rune __isalphas[] = { + 0x00aa, 0x00b5, 0x00ba, 0x02ec, 0x02ee, 0x0386, 0x038c, 0x0559, + 0x06d5, 0x06ff, 0x0710, 0x07b1, 0x07fa, 0x081a, 0x0824, 0x0828, + 0x08a0, 0x093d, 0x0950, 0x09b2, 0x09bd, 0x09ce, 0x0a5e, 0x0abd, + 0x0ad0, 0x0b3d, 0x0b71, 0x0b83, 0x0b9c, 0x0bd0, 0x0c3d, 0x0cbd, + 0x0cde, 0x0d3d, 0x0d4e, 0x0dbd, 0x0e84, 0x0e8a, 0x0e8d, 0x0ea5, + 0x0ea7, 0x0ebd, 0x0ec6, 0x0f00, 0x103f, 0x1061, 0x108e, 0x10c7, + 0x10cd, 0x1258, 0x12c0, 0x17d7, 0x17dc, 0x18aa, 0x1aa7, 0x1f59, + 0x1f5b, 0x1f5d, 0x1fbe, 0x2071, 0x207f, 0x2102, 0x2107, 0x2115, + 0x2124, 0x2126, 0x2128, 0x214e, 0x2d27, 0x2d2d, 0x2d6f, 0x2e2f, + 0xa8fb, 0xa9cf, 0xaa7a, 0xaab1, 0xaac0, 0xaac2, 0xfb1d, 0xfb3e, + 0x10808, 0x1083c, 0x10a00, 0x16f50, 0x1d4a2, 0x1d4bb, 0x1d546, 0x1ee24, + 0x1ee27, 0x1ee39, 0x1ee3b, 0x1ee42, 0x1ee47, 0x1ee49, 0x1ee4b, 0x1ee54, + 0x1ee57, 0x1ee59, 0x1ee5b, 0x1ee5d, 0x1ee5f, 0x1ee64, 0x1ee7e, +}; + +int utf_isalpharune(Rune c) { + Rune *p; + + p = rbsearch(c, __isalphar, nelem(__isalphar) / 2, 2); + if (p && c >= p[0] && c <= p[1]) return 1; + p = rbsearch(c, __isalphas, nelem(__isalphas), 1); + if (p && c == p[0]) return 1; + return 0; +} + +static Rune __tolowerr[] = { + 0x0041, 0x005a, 1048608, 0x00c0, 0x00d6, 1048608, 0x00d8, 0x00de, 1048608, + 0x0189, 0x018a, 1048781, 0x01b1, 0x01b2, 1048793, 0x0388, 0x038a, 1048613, + 0x038e, 0x038f, 1048639, 0x0391, 0x03a1, 1048608, 0x03a3, 0x03ab, 1048608, + 0x03fd, 0x03ff, 1048446, 0x0400, 0x040f, 1048656, 0x0410, 0x042f, 1048608, + 0x0531, 0x0556, 1048624, 0x10a0, 0x10c5, 1055840, 0x1f08, 0x1f0f, 1048568, + 0x1f18, 0x1f1d, 1048568, 0x1f28, 0x1f2f, 1048568, 0x1f38, 0x1f3f, 1048568, + 0x1f48, 0x1f4d, 1048568, 0x1f68, 0x1f6f, 1048568, 0x1f88, 0x1f8f, 1048568, + 0x1f98, 0x1f9f, 1048568, 0x1fa8, 0x1faf, 1048568, 0x1fb8, 0x1fb9, 1048568, + 0x1fba, 0x1fbb, 1048502, 0x1fc8, 0x1fcb, 1048490, 0x1fd8, 0x1fd9, 1048568, + 0x1fda, 0x1fdb, 1048476, 0x1fe8, 0x1fe9, 1048568, 0x1fea, 0x1feb, 1048464, + 0x1ff8, 0x1ff9, 1048448, 0x1ffa, 0x1ffb, 1048450, 0x2160, 0x216f, 1048592, + 0x24b6, 0x24cf, 1048602, 0x2c00, 0x2c2e, 1048624, 0x2c7e, 0x2c7f, 1037761, + 0xff21, 0xff3a, 1048608, 0x10400, 0x10427, 1048616, +}; + +static Rune __tolowerp[] = { + 0x0100, 0x012e, 1048577, 0x0132, 0x0136, 1048577, 0x0139, 0x0147, 1048577, + 0x014a, 0x0176, 1048577, 0x017b, 0x017d, 1048577, 0x01a2, 0x01a4, 1048577, + 0x01b3, 0x01b5, 1048577, 0x01cd, 0x01db, 1048577, 0x01de, 0x01ee, 1048577, + 0x01f8, 0x021e, 1048577, 0x0222, 0x0232, 1048577, 0x0248, 0x024e, 1048577, + 0x0370, 0x0372, 1048577, 0x03d8, 0x03ee, 1048577, 0x0460, 0x0480, 1048577, + 0x048a, 0x04be, 1048577, 0x04c3, 0x04cd, 1048577, 0x04d0, 0x0526, 1048577, + 0x1e00, 0x1e94, 1048577, 0x1ea0, 0x1efe, 1048577, 0x1f59, 0x1f5f, 1048568, + 0x2c67, 0x2c6b, 1048577, 0x2c80, 0x2ce2, 1048577, 0x2ceb, 0x2ced, 1048577, + 0xa640, 0xa66c, 1048577, 0xa680, 0xa696, 1048577, 0xa722, 0xa72e, 1048577, + 0xa732, 0xa76e, 1048577, 0xa779, 0xa77b, 1048577, 0xa780, 0xa786, 1048577, + 0xa790, 0xa792, 1048577, 0xa7a0, 0xa7a8, 1048577, +}; + +static Rune __tolowers[] = { + 0x0130, 1048377, 0x0178, 1048455, 0x0179, 1048577, 0x0181, 1048786, + 0x0182, 1048577, 0x0184, 1048577, 0x0186, 1048782, 0x0187, 1048577, + 0x018b, 1048577, 0x018e, 1048655, 0x018f, 1048778, 0x0190, 1048779, + 0x0191, 1048577, 0x0193, 1048781, 0x0194, 1048783, 0x0196, 1048787, + 0x0197, 1048785, 0x0198, 1048577, 0x019c, 1048787, 0x019d, 1048789, + 0x019f, 1048790, 0x01a0, 1048577, 0x01a6, 1048794, 0x01a7, 1048577, + 0x01a9, 1048794, 0x01ac, 1048577, 0x01ae, 1048794, 0x01af, 1048577, + 0x01b7, 1048795, 0x01b8, 1048577, 0x01bc, 1048577, 0x01c4, 1048578, + 0x01c5, 1048577, 0x01c7, 1048578, 0x01c8, 1048577, 0x01ca, 1048578, + 0x01cb, 1048577, 0x01f1, 1048578, 0x01f2, 1048577, 0x01f4, 1048577, + 0x01f6, 1048479, 0x01f7, 1048520, 0x0220, 1048446, 0x023a, 1059371, + 0x023b, 1048577, 0x023d, 1048413, 0x023e, 1059368, 0x0241, 1048577, + 0x0243, 1048381, 0x0244, 1048645, 0x0245, 1048647, 0x0246, 1048577, + 0x0376, 1048577, 0x0386, 1048614, 0x038c, 1048640, 0x03cf, 1048584, + 0x03f4, 1048516, 0x03f7, 1048577, 0x03f9, 1048569, 0x03fa, 1048577, + 0x04c0, 1048591, 0x04c1, 1048577, 0x10c7, 1055840, 0x10cd, 1055840, + 0x1e9e, 1040961, 0x1fbc, 1048567, 0x1fcc, 1048567, 0x1fec, 1048569, + 0x1ffc, 1048567, 0x2126, 1041059, 0x212a, 1040193, 0x212b, 1040314, + 0x2132, 1048604, 0x2183, 1048577, 0x2c60, 1048577, 0x2c62, 1037833, + 0x2c63, 1044762, 0x2c64, 1037849, 0x2c6d, 1037796, 0x2c6e, 1037827, + 0x2c6f, 1037793, 0x2c70, 1037794, 0x2c72, 1048577, 0x2c75, 1048577, + 0x2cf2, 1048577, 0xa77d, 1013244, 0xa77e, 1048577, 0xa78b, 1048577, + 0xa78d, 1006296, 0xa7aa, 1006268, +}; + +Rune utf_tolowerrune(Rune c) { + Rune *p; + + p = rbsearch(c, __tolowerr, nelem(__tolowerr) / 3, 3); + if (p && c >= p[0] && c <= p[1]) return c + p[2] - 1048576; + p = rbsearch(c, __tolowerp, nelem(__tolowerp) / 3, 3); + if (p && c >= p[0] && c <= p[1] && !((c - p[0]) & 1)) + return c + p[2] - 1048576; + p = rbsearch(c, __tolowers, nelem(__tolowers) / 2, 2); + if (p && c == p[0]) return c + p[1] - 1048576; + return c; +} + +#endif diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h new file mode 100644 index 000000000..f3b14772e --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Fork of several UTF utils originally written by Rob Pike and Ken Thompson. +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ 1 + +#include + +// Code-point values in Unicode 4.0 are 21 bits wide. +typedef signed int Rune; + +#define uchar _utfuchar + +typedef unsigned char uchar; + +#define nelem(x) (sizeof(x) / sizeof((x)[0])) + +enum { + UTFmax = 4, // maximum bytes per rune + Runeerror = 0xFFFD, // decoding error in UTF + Runemax = 0x10FFFF, // maximum rune value +}; + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * rune routines + */ + +/* + * These routines were written by Rob Pike and Ken Thompson + * and first appeared in Plan 9. + * SEE ALSO + * utf (7) + * tcs (1) + */ + +// utf_runetochar copies (encodes) one rune, pointed to by r, to at most +// UTFmax bytes starting at s and returns the number of bytes generated. + +int utf_runetochar(char* s, const Rune* r); + +// utf_charntorune copies (decodes) at most UTFmax bytes starting at `str` to +// one rune, pointed to by `rune`, accesss at most `length` bytes of `str`, and +// returns the number of bytes consumed. +// If the UTF sequence is incomplete within n bytes, +// utf_charntorune will set *r to Runeerror and return 0. If it is complete +// but not in UTF format, it will set *r to Runeerror and return 1. +// +// Added 2004-09-24 by Wei-Hwa Huang + +int utf_charntorune(Rune* rune, const char* str, int length); + +// Unicode defines some characters as letters and +// specifies three cases: upper, lower, and title. Mappings among the +// cases are also defined, although they are not exhaustive: some +// upper case letters have no lower case mapping, and so on. Unicode +// also defines several character properties, a subset of which are +// checked by these routines. These routines are based on Unicode +// version 3.0.0. +// +// NOTE: The routines are implemented in C, so isalpharrune returns 0 for false +// and 1 for true. +// +// utf_tolowerrune is the Unicode case mapping. It returns the character +// unchanged if it has no defined mapping. + +Rune utf_tolowerrune(Rune r); + +// utf_isalpharune tests for Unicode letters; this includes ideographs in +// addition to alphabetic characters. + +int utf_isalpharune(Rune r); + +// (The comments in this file were copied from the manpage files rune.3, +// isalpharune.3, and runestrcat.3. Some formatting changes were also made +// to conform to Google style. /JRM 11/11/05) + +#ifdef __cplusplus +} +#endif + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ From 18d88c531ac54b6369dc45f558b23cda20895efc Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Mar 2023 11:29:16 -0700 Subject: [PATCH 05/10] Internal MediaPipe Tasks change. PiperOrigin-RevId: 516881879 --- .../text/language_detector/custom_ops/BUILD | 33 ++ .../custom_ops/ngram_hash.cc | 264 +++++++++++++++ .../language_detector/custom_ops/ngram_hash.h | 27 ++ .../custom_ops/ngram_hash_test.cc | 313 ++++++++++++++++++ 4 files changed, 637 insertions(+) create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h create mode 100644 mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD index 5e7c5afa5..090f528ef 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD @@ -42,3 +42,36 @@ cc_test( "@org_tensorflow//tensorflow/lite/kernels:test_util", ], ) + +cc_library( + name = "ngram_hash", + srcs = ["ngram_hash.cc"], + hdrs = ["ngram_hash.h"], + copts = tflite_copts(), + deps = [ + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils", + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], + alwayslink = 1, +) + +cc_test( + name = "ngram_hash_test", + srcs = ["ngram_hash_test.cc"], + deps = [ + ":ngram_hash", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@com_google_absl//absl/types:optional", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc new file mode 100644 index 000000000..738fa1128 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc @@ -0,0 +1,264 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" + +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite::ops::custom { + +namespace ngram_op { + +namespace { + +using ::flexbuffers::GetRoot; +using ::flexbuffers::Map; +using ::flexbuffers::TypedVector; +using ::mediapipe::tasks::text::language_detector::custom_ops:: + LowercaseUnicodeStr; +using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize; +using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput; +using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: + MurmurHash64WithSeed; +using ::tflite::GetString; +using ::tflite::StringRef; + +constexpr int kInputMessage = 0; +constexpr int kOutputLabel = 0; +constexpr int kDefaultMaxSplits = 128; + +// This op takes in a string, finds the character ngrams for it and then +// maps each of these ngrams to an index using the specified vocabulary sizes. + +// Input(s): +// - input: Input string. +// - seeds: Seed for the random number generator. +// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would +// be interpreted as generating unigrams, bigrams, and trigrams. +// - vocab_sizes: Size of the vocabulary for each of the ngram features +// respectively. The op would generate vocab ids to be less than or equal to +// the vocab size. The index 0 implies an invalid ngram. +// - max_splits: Maximum number of tokens in the output. If this is unset, the +// limit is `kDefaultMaxSplits`. +// - lower_case_input: If this is set to true, the input string would be +// lower-cased before any processing. + +// Output(s): +// - output: A tensor of size [number of ngrams, number of tokens + 2], +// where 2 tokens are reserved for the padding. If `max_splits` is set, this +// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`. + +// Helper class used for pre-processing the input. +class NGramHashParams { + public: + NGramHashParams(const uint64_t seed, const std::vector& ngram_lengths, + const std::vector& vocab_sizes, int max_splits, + bool lower_case_input) + : seed_(seed), + ngram_lengths_(ngram_lengths), + vocab_sizes_(vocab_sizes), + max_splits_(max_splits), + lower_case_input_(lower_case_input) {} + + TfLiteStatus PreprocessInput(const TfLiteTensor* input_t, + TfLiteContext* context) { + if (input_t->bytes == 0) { + context->ReportError(context, "Empty input not supported."); + return kTfLiteError; + } + + // Do sanity checks on the input. + if (ngram_lengths_.empty()) { + context->ReportError(context, "`ngram_lengths` must be non-empty."); + return kTfLiteError; + } + + if (vocab_sizes_.empty()) { + context->ReportError(context, "`vocab_sizes` must be non-empty."); + return kTfLiteError; + } + + if (ngram_lengths_.size() != vocab_sizes_.size()) { + context->ReportError( + context, + "Sizes of `ngram_lengths` and `vocab_sizes` must be the same."); + return kTfLiteError; + } + + if (max_splits_ <= 0) { + context->ReportError(context, "`max_splits` must be > 0."); + return kTfLiteError; + } + + // Obtain and tokenize the input. + StringRef inputref = GetString(input_t, /*string_index=*/0); + if (lower_case_input_) { + std::string lower_cased_str; + LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str); + + tokenized_output_ = + Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_, + /*exclude_nonalphaspace_tokens=*/true); + } else { + tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_, + /*exclude_nonalphaspace_tokens=*/true); + } + return kTfLiteOk; + } + uint64_t GetSeed() const { return seed_; } + + int GetNumTokens() const { return tokenized_output_.tokens.size(); } + + int GetNumNGrams() const { return ngram_lengths_.size(); } + + std::vector GetNGramLengths() const { return ngram_lengths_; } + + std::vector GetVocabSizes() const { return vocab_sizes_; } + + const TokenizedOutput& GetTokenizedOutput() const { + return tokenized_output_; + } + + TokenizedOutput tokenized_output_; + + private: + const uint64_t seed_; + std::vector ngram_lengths_; + std::vector vocab_sizes_; + const int max_splits_; + const bool lower_case_input_; +}; + +// Convert the TypedVector into a regular std::vector. +std::vector GetIntVector(TypedVector typed_vec) { + std::vector vec(typed_vec.size()); + for (int j = 0; j < typed_vec.size(); j++) { + vec[j] = typed_vec[j].AsInt32(); + } + return vec; +} + +void GetNGramHashIndices(NGramHashParams* params, int32_t* data) { + const int max_unicode_length = params->GetNumTokens(); + const auto ngram_lengths = params->GetNGramLengths(); + const auto vocab_sizes = params->GetVocabSizes(); + const auto& tokenized_output = params->GetTokenizedOutput(); + const auto seed = params->GetSeed(); + + // Compute for each ngram. + for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) { + const int vocab_size = vocab_sizes[ngram]; + const int ngram_length = ngram_lengths[ngram]; + + // Compute for each token within the input. + for (int start = 0; start < tokenized_output.tokens.size(); start++) { + // Compute the number of bytes for the ngram starting at the given + // token. + int num_bytes = 0; + for (int i = start; + i < tokenized_output.tokens.size() && i < (start + ngram_length); + i++) { + num_bytes += tokenized_output.tokens[i].second; + } + + // Compute the hash for the ngram starting at the token. + const auto str_hash = MurmurHash64WithSeed( + tokenized_output.str.c_str() + tokenized_output.tokens[start].first, + num_bytes, seed); + + // Map the hash to an index in the vocab. + data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1; + } + } +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const uint8_t* buffer_t = reinterpret_cast(buffer); + const Map& m = GetRoot(buffer_t, length).AsMap(); + + const uint64_t seed = m["seed"].AsUInt64(); + const std::vector ngram_lengths = + GetIntVector(m["ngram_lengths"].AsTypedVector()); + const std::vector vocab_sizes = + GetIntVector(m["vocab_sizes"].AsTypedVector()); + const int max_splits = + m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32(); + const bool lowercase_input = + m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool(); + + return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits, + lowercase_input); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + NGramHashParams* params = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_OK( + context, + params->PreprocessInput(GetInput(context, node, kInputMessage), context)); + + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + if (IsDynamicTensor(output)) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); + output_size->data[0] = 1; + output_size->data[1] = params->GetNumNGrams(); + output_size->data[2] = params->GetNumTokens(); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + } else { + context->ReportError(context, "Output must by dynamic."); + return kTfLiteError; + } + + if (output->type == kTfLiteInt32) { + GetNGramHashIndices(params, output->data.i32); + } else { + context->ReportError(context, "Output type must be Int32."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace ngram_op + +TfLiteRegistration* Register_NGRAM_HASH() { + static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free, + ngram_op::Resize, ngram_op::Eval}; + return &r; +} + +} // namespace tflite::ops::custom diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h new file mode 100644 index 000000000..a061357bd --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace tflite::ops::custom { + +TfLiteRegistration* Register_NGRAM_HASH(); + +} // namespace tflite::ops::custom + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc new file mode 100644 index 000000000..28d2dea6e --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc @@ -0,0 +1,313 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite::ops::custom { +namespace { + +using ::flexbuffers::Builder; +using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: + MurmurHash64WithSeed; +using ::testing::ElementsAreArray; +using ::testing::Message; + +// Helper class for testing the op. +class NGramHashModel : public SingleOpModel { + public: + explicit NGramHashModel(const uint64_t seed, + const std::vector& ngram_lengths, + const std::vector& vocab_sizes, + const absl::optional max_splits = std::nullopt) { + // Setup the model inputs. + Builder fbb; + size_t start = fbb.StartMap(); + fbb.UInt("seed", seed); + { + size_t start = fbb.StartVector("ngram_lengths"); + for (const int& ngram_len : ngram_lengths) { + fbb.Int(ngram_len); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + { + size_t start = fbb.StartVector("vocab_sizes"); + for (const int& vocab_size : vocab_sizes) { + fbb.Int(vocab_size); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + if (max_splits) { + fbb.Int("max_splits", *max_splits); + } + fbb.EndMap(start); + fbb.Finish(); + output_ = AddOutput({TensorType_INT32, {}}); + SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH); + BuildInterpreter({GetShape(input_)}); + } + + void SetupInputTensor(const std::string& input) { + PopulateStringTensor(input_, {input}); + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot allocate tensors"; + } + + void Invoke(const std::string& input) { + SetupInputTensor(input); + CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); + } + + TfLiteStatus InvokeUnchecked(const std::string& input) { + SetupInputTensor(input); + return SingleOpModel::Invoke(); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_ = AddInput(TensorType_STRING); + int output_; +}; + +TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) { + // Checks that the op returns the expected value when the input is sane. + // Also checks that when `max_splits` is not specified, the entire string is + // tokenized. + const uint64_t kSeed = 123; + const std::vector vocab_sizes({100, 200}); + std::vector ngram_lengths({1, 2}); + const std::vector testcase_inputs({ + "hi", + "wow", + "!", + "HI", + }); + + // A hash function that maps the given string to an index in the embedding + // table denoted by `vocab_idx`. + auto hash = [vocab_sizes](std::string str, const int vocab_idx) { + const auto hash_value = + MurmurHash64WithSeed(str.c_str(), str.size(), kSeed); + return static_cast((hash_value % vocab_sizes[vocab_idx]) + 1); + }; + const std::vector> expected_testcase_outputs( + {{ + // Unigram & Bigram output for "hi". + hash("^", 0), + hash("h", 0), + hash("i", 0), + hash("$", 0), + hash("^h", 1), + hash("hi", 1), + hash("i$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow". + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "!" (which will get replaced by " "). + hash("^", 0), + hash(" ", 0), + hash("$", 0), + hash("^ ", 1), + hash(" $", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "HI" (which will get lower-cased). + hash("^", 0), + hash("h", 0), + hash("i", 0), + hash("$", 0), + hash("^h", 1), + hash("hi", 1), + hash("i$", 1), + hash("$", 1), + }}); + + NGramHashModel m(kSeed, ngram_lengths, vocab_sizes); + for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) { + const string& testcase_input = testcase_inputs[test_idx]; + m.Invoke(testcase_input); + SCOPED_TRACE(Message() << "Where the testcases' input is: " + << testcase_input); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(expected_testcase_outputs[test_idx])); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray( + {/*batch_size=*/1, static_cast(ngram_lengths.size()), + static_cast(testcase_input.size()) + /*padding*/ 2})); + } +} + +TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) { + // Checks that the op returns the expected value when the input is correct + // when `max_splits` is specified. + const uint64_t kSeed = 123; + const std::vector vocab_sizes({100, 200}); + std::vector ngram_lengths({1, 2}); + + const std::string testcase_input = "wow"; + const std::vector max_splits({2, 3, 4, 5, 6}); + + // A hash function that maps the given string to an index in the embedding + // table denoted by `vocab_idx`. + auto hash = [vocab_sizes](std::string str, const int vocab_idx) { + const auto hash_value = + MurmurHash64WithSeed(str.c_str(), str.size(), kSeed); + return static_cast((hash_value % vocab_sizes[vocab_idx]) + 1); + }; + + const std::vector> expected_testcase_outputs( + {{ + // Unigram & Bigram output for "wow", when `max_splits` == 2. + // We cannot include any of the actual tokens, since `max_splits` + // only allows enough space for the delimiters. + hash("^", 0), + hash("$", 0), + hash("^$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 3. + // We can start to include some tokens from the input string. + hash("^", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 4. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("o$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 5. + // We can include the full input string. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 6. + // `max_splits` is more than the full input string. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }}); + + for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) { + const int testcase_max_splits = max_splits[test_idx]; + NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits); + m.Invoke(testcase_input); + SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(expected_testcase_outputs[test_idx])); + EXPECT_THAT( + m.GetOutputShape(), + ElementsAreArray( + {/*batch_size=*/1, static_cast(ngram_lengths.size()), + std::min( + // Longest possible tokenization when using the entire + // input. + static_cast(testcase_input.size()) + /*padding*/ 2, + // Longest possible string when the `max_splits` value + // is < testcase_input.size() + 2 for padding. + testcase_max_splits)})); + } +} + +TEST(NGramHashTest, InvalidMaxSplitsValue) { + // Check that the op errors out when given an invalid max splits value. + const std::vector invalid_max_splits({0, -1, -5, -100}); + for (const int max_splits : invalid_max_splits) { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200}, + /*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } +} + +TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) { + // Check that the op errors out when ngram lengths and vocab sizes mistmatch. + { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300}, + /*vocab_sizes=*/{1, 2}); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } + { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200}, + /*vocab_sizes=*/{1, 2, 3}); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } +} + +} // namespace +} // namespace tflite::ops::custom From a32382513468d6d0ea2888c0fce0e754df57329a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Mar 2023 11:31:28 -0700 Subject: [PATCH 06/10] Internal change PiperOrigin-RevId: 516882513 --- .../tensors_to_segmentation_calculator.cc | 209 ++++++++++++------ .../image_segmenter/image_segmenter_graph.cc | 2 +- .../image_segmenter/image_segmenter_test.cc | 91 +++++++- mediapipe/tasks/testdata/vision/BUILD | 6 + third_party/external_files.bzl | 26 ++- 5 files changed, 250 insertions(+), 84 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 091e4d6c9..b6c1fe6b0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -79,6 +80,133 @@ void Sigmoid(absl::Span values, [](float value) { return 1. / (1 + std::exp(-value)); }); } +std::vector ProcessForCategoryMaskCpu(const Shape& input_shape, + const Shape& output_shape, + const SegmenterOptions& options, + const float* tensors_buffer) { + cv::Mat resized_tensors_mat; + cv::Mat tensors_mat_view( + input_shape.height, input_shape.width, CV_32FC(input_shape.channels), + reinterpret_cast(const_cast(tensors_buffer))); + if (output_shape.height == input_shape.height && + output_shape.width == input_shape.width) { + resized_tensors_mat = tensors_mat_view; + } else { + // Resize input tensors to output size. + // TOOD(b/273633027) Use an efficient way to find values for category mask + // instead of resizing the whole tensor . + cv::resize(tensors_mat_view, resized_tensors_mat, + {output_shape.width, output_shape.height}, 0, 0, + cv::INTER_LINEAR); + } + + // Category mask Image. + ImageFrameSharedPtr image_frame_ptr = std::make_shared( + ImageFormat::GRAY8, output_shape.width, output_shape.height, 1); + Image category_mask(image_frame_ptr); + + // Fill in the maximum category in the category mask image. + cv::Mat category_mask_mat_view = + mediapipe::formats::MatView(image_frame_ptr.get()); + const int input_channels = input_shape.channels; + category_mask_mat_view.forEach( + [&resized_tensors_mat, &input_channels, &options](uint8_t& pixel, + const int position[]) { + float* tensors_buffer = + resized_tensors_mat.ptr(position[0], position[1]); + absl::Span confidence_scores(tensors_buffer, input_channels); + // Only process the activation function if it is SIGMOID. If NONE, + // we do nothing for activation, If SOFTMAX, it is required + // to have input_channels > 1, and for input_channels > 1, we don't need + // activation to find the maximum value. + if (options.activation() == SegmenterOptions::SIGMOID) { + Sigmoid(confidence_scores, confidence_scores); + } + if (input_channels == 1) { + // if the input tensor is a single mask, it is assumed to be a binary + // foreground segmentation mask. For such a mask, we make foreground + // category 1, and background category 0. + pixel = static_cast(*tensors_buffer > 0.5f); + } else { + const int maximum_category_idx = + std::max_element(confidence_scores.begin(), + confidence_scores.end()) - + confidence_scores.begin(); + pixel = maximum_category_idx; + } + }); + return {category_mask}; +} + +std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, + const Shape& output_shape, + const SegmenterOptions& options, + const float* tensors_buffer) { + std::function values, + absl::Span activated_values)> + activation_fn; + switch (options.activation()) { + case SegmenterOptions::SIGMOID: + activation_fn = &Sigmoid; + break; + case SegmenterOptions::SOFTMAX: + activation_fn = &StableSoftmax; + break; + case SegmenterOptions::NONE: + // Just copying for NONE activation. + activation_fn = [](absl::Span values, + absl::Span activated_values) { + std::copy(values.begin(), values.end(), activated_values.begin()); + }; + break; + } + + // TODO Use libyuv for resizing instead. + std::vector confidence_masks; + std::vector confidence_mask_mats; + confidence_masks.reserve(input_shape.channels); + confidence_mask_mats.reserve(input_shape.channels); + for (int i = 0; i < input_shape.channels; ++i) { + confidence_masks.push_back(Image(std::make_shared( + ImageFormat::VEC32F1, input_shape.width, input_shape.height, 1))); + confidence_mask_mats.push_back(mediapipe::formats::MatView( + confidence_masks.back().GetImageFrameSharedPtr().get())); + } + + // Applies activation function. + const int tensor_size = input_shape.height * input_shape.width; + std::vector activated_values(input_shape.channels); + absl::Span activated_values_span(activated_values); + for (int i = 0; i < tensor_size; ++i) { + activation_fn(absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels], + input_shape.channels), + activated_values_span); + for (int j = 0; j < input_shape.channels; ++j) { + confidence_mask_mats[j].at( + i / input_shape.width, i % input_shape.width) = activated_values[j]; + } + } + if (output_shape.height == input_shape.height && + output_shape.width == input_shape.width) { + return confidence_masks; + } + std::vector resized_confidence_masks; + resized_confidence_masks.reserve(confidence_mask_mats.size()); + // Resizes segmented masks to required output size. + for (int i = 0; i < confidence_mask_mats.size(); i++) { + // Pre-allocates ImageFrame memory to avoid copying from cv::Mat + // afterward. + ImageFrameSharedPtr image_frame_ptr = std::make_shared( + ImageFormat::VEC32F1, output_shape.width, output_shape.height, 1); + cv::Mat resized_mask_mat_view = + mediapipe::formats::MatView(image_frame_ptr.get()); + cv::resize(confidence_mask_mats[i], resized_mask_mat_view, + resized_mask_mat_view.size(), 0, 0, cv::INTER_LINEAR); + resized_confidence_masks.push_back(Image(image_frame_ptr)); + } + return resized_confidence_masks; +} + } // namespace // Converts Tensors from a vector of Tensor to Segmentation. @@ -222,81 +350,16 @@ absl::Status TensorsToSegmentationCalculator::Process( std::vector TensorsToSegmentationCalculator::GetSegmentationResultCpu( const Shape& input_shape, const Shape& output_shape, const float* tensors_buffer) { - std::function values, - absl::Span activated_values)> - activation_fn; - switch (options_.segmenter_options().activation()) { - case SegmenterOptions::SIGMOID: - activation_fn = &Sigmoid; - break; - case SegmenterOptions::SOFTMAX: - activation_fn = &StableSoftmax; - break; - case SegmenterOptions::NONE: - // Just copying for NONE activation. - activation_fn = [](absl::Span values, - absl::Span activated_values) { - std::copy(values.begin(), values.end(), activated_values.begin()); - }; - break; - } - - const bool is_category_mask = options_.segmenter_options().output_type() == - SegmenterOptions::CATEGORY_MASK; - const int cv_mat_type = is_category_mask ? CV_8UC1 : CV_32FC1; - const int output_masks_num = output_shape.channels; - - // TODO Use libyuv for resizing instead. - std::vector segmented_mask_mats; - segmented_mask_mats.reserve(output_masks_num); - for (int i = 0; i < output_masks_num; ++i) { - segmented_mask_mats.push_back( - cv::Mat(input_shape.height, input_shape.width, cv_mat_type)); - } - - // Applies activation function. - const int tensor_size = input_shape.height * input_shape.width; - if (is_category_mask) { - for (int i = 0; i < tensor_size; ++i) { - absl::Span confidence_scores( - &tensors_buffer[i * input_shape.channels], input_shape.channels); - const int maximum_category_idx = - std::max_element(confidence_scores.begin(), confidence_scores.end()) - - confidence_scores.begin(); - segmented_mask_mats[0].at( - i / input_shape.width, i % input_shape.width) = maximum_category_idx; - } + if (options_.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK) { + return ProcessForCategoryMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer); } else { - std::vector activated_values(input_shape.channels); - absl::Span activated_values_span(activated_values); - for (int i = 0; i < tensor_size; ++i) { - activation_fn( - absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels], - input_shape.channels), - activated_values_span); - for (int j = 0; j < input_shape.channels; ++j) { - segmented_mask_mats[j].at( - i / input_shape.width, i % input_shape.width) = activated_values[j]; - } - } + return ProcessForConfidenceMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer); } - - std::vector segmented_masks; - segmented_masks.reserve(output_masks_num); - // Resizes segmented masks to required output size. - for (int i = 0; i < segmented_mask_mats.size(); i++) { - // Pre-allocates ImageFrame memory to avoid copying from cv::Mat afterward. - ImageFrameSharedPtr image_frame_ptr = std::make_shared( - is_category_mask ? ImageFormat::GRAY8 : ImageFormat::VEC32F1, - output_shape.width, output_shape.height, 1); - cv::Mat resized_mask_mat_view = - mediapipe::formats::MatView(image_frame_ptr.get()); - cv::resize(segmented_mask_mats[i], resized_mask_mat_view, - resized_mask_mat_view.size(), 0, 0, - cv_mat_type == CV_8UC1 ? cv::INTER_NEAREST : cv::INTER_LINEAR); - segmented_masks.push_back(Image(image_frame_ptr)); - } - return segmented_masks; } MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index c4a4065c6..6a7e08626 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -401,7 +401,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { } else { ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, GetOutputTensor(model_resources)); - const int segmentation_streams_num = *output_tensor->shape()->rbegin(); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); for (int i = 0; i < segmentation_streams_num; ++i) { segmented_masks.push_back(Source( tensor_to_images[Output::Multiple(kSegmentationTag)][i])); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index ab5d184db..d063ca87a 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -62,6 +62,11 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite"; constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite"; +constexpr char kSelfieSegmentation[] = "selfie_segmentation.tflite"; + +constexpr char kSelfieSegmentationLandscape[] = + "selfie_segmentation_landscape.tflite"; + constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite"; constexpr float kGoldenMaskSimilarity = 0.98; @@ -90,13 +95,8 @@ cv::Mat PostProcessResultMask(const cv::Mat& mask) { } Image GetSRGBImage(const std::string& image_path) { - // TODO: fix test so RGB really is used and not BGR/BGRA. - // mediapipe/app/aimatter/segmentation/segmenter_test_common.cc - // golden masks are generated with BGR image. To align with the unittest of - // aimatter segmenter, here reads image as BGR as well (opencv reads image as - // BGR). Once the correctness of mediapipe tasks segmenter is verified, change - // the golden masks to be generated by RGB image. cv::Mat image_mat = cv::imread(image_path); + cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGB); mediapipe::ImageFrame image_frame( mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows, image_mat.step, image_mat.data, [image_mat](uint8_t[]) {}); @@ -435,6 +435,85 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { + Image image = + GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentation); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::NONE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); + EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK(segmenter->Close()); + + cv::Mat expected_mask = cv::imread( + JoinPath("./", kTestDataDirectory, + "portrait_selfie_segmentation_expected_confidence_mask.jpg"), + cv::IMREAD_GRAYSCALE); + cv::Mat expected_mask_float; + expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); + + cv::Mat selfie_mask = mediapipe::formats::MatView( + confidence_masks[0].GetImageFrameSharedPtr().get()); + EXPECT_THAT(selfie_mask, + SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); +} + +TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { + Image image = + GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentation); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->activation = ImageSegmenterOptions::Activation::NONE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); + EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK(segmenter->Close()); + + cv::Mat selfie_mask = mediapipe::formats::MatView( + category_mask[0].GetImageFrameSharedPtr().get()); + cv::Mat expected_mask = cv::imread( + JoinPath("./", kTestDataDirectory, + "portrait_selfie_segmentation_expected_category_mask.jpg"), + cv::IMREAD_GRAYSCALE); + EXPECT_THAT(selfie_mask, + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); +} + +TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { + Image image = + GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->activation = ImageSegmenterOptions::Activation::NONE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); + EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK(segmenter->Close()); + + cv::Mat selfie_mask = mediapipe::formats::MatView( + category_mask[0].GetImageFrameSharedPtr().get()); + cv::Mat expected_mask = cv::imread( + JoinPath( + "./", kTestDataDirectory, + "portrait_selfie_segmentation_landscape_expected_category_mask.jpg"), + cv::IMREAD_GRAYSCALE); + EXPECT_THAT(selfie_mask, + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); +} + TEST_F(ImageModeTest, SucceedsHairSegmentation) { Image image = GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 087d0ea75..ac76bfa23 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -70,6 +70,9 @@ mediapipe_files(srcs = [ "portrait.jpg", "portrait_hair_expected_mask.jpg", "portrait_rotated.jpg", + "portrait_selfie_segmentation_expected_category_mask.jpg", + "portrait_selfie_segmentation_expected_confidence_mask.jpg", + "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", "pose.jpg", "pose_detection.tflite", "right_hands.jpg", @@ -129,6 +132,9 @@ filegroup( "portrait.jpg", "portrait_hair_expected_mask.jpg", "portrait_rotated.jpg", + "portrait_selfie_segmentation_expected_category_mask.jpg", + "portrait_selfie_segmentation_expected_confidence_mask.jpg", + "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", "pose.jpg", "right_hands.jpg", "right_hands_rotated.jpg", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index b290fbcbe..52636f427 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -886,6 +886,24 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_rotated.jpg?generation=1677194680138164"], ) + http_file( + name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg", + sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"], + ) + + http_file( + name = "com_google_mediapipe_portrait_selfie_segmentation_expected_confidence_mask_jpg", + sha256 = "25b723e90608edaf6ed92f382da703dc904a59c87525b6d271e60d9eed7a90e9", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_confidence_mask.jpg?generation=1678606937358235"], + ) + + http_file( + name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg", + sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"], + ) + http_file( name = "com_google_mediapipe_pose_detection_tflite", sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c", @@ -1014,8 +1032,8 @@ def external_files(): http_file( name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg", - sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1661875916766416"], + sha256 = "1a2a068287d8bcd4184492485b3dbb95a09b763f4653fd729d14a836147eb383", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1678606942616777"], ) http_file( @@ -1026,8 +1044,8 @@ def external_files(): http_file( name = "com_google_mediapipe_selfie_segm_144_256_3_expected_mask_jpg", - sha256 = "cfc699db9670585c04414d0d1a07b289a027ba99d6903d2219f897d34e2c9952", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1661875922646736"], + sha256 = "2de433b6e8adabec2aaf80135232db900903ead4f2811c0c9378a6792b2a68b5", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1678606945085676"], ) http_file( From 59962bed27cde01045756778850400e770ed4486 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Mar 2023 12:08:28 -0700 Subject: [PATCH 07/10] ImageSegmenterGraph set activation type from metadata, and remove the activation config in C++ ImageSegmenterOptions. PiperOrigin-RevId: 516893115 --- .../tasks/cc/vision/image_segmenter/BUILD | 1 + .../vision/image_segmenter/image_segmenter.cc | 14 ------ .../vision/image_segmenter/image_segmenter.h | 9 ---- .../image_segmenter/image_segmenter_graph.cc | 48 +++++++++++++++++-- .../image_segmenter/image_segmenter_test.cc | 9 ---- .../vision/imagesegmenter/ImageSegmenter.java | 3 -- 6 files changed, 46 insertions(+), 38 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 1123204ce..69833a5f6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -80,6 +80,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 9769b47d5..c12fe7f7e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -101,20 +101,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { SegmenterOptions::CONFIDENCE_MASK); break; } - switch (options->activation) { - case ImageSegmenterOptions::Activation::NONE: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::NONE); - break; - case ImageSegmenterOptions::Activation::SIGMOID: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::SIGMOID); - break; - case ImageSegmenterOptions::Activation::SOFTMAX: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::SOFTMAX); - break; - } return options_proto; } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index c757296e4..076a5016c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -64,15 +64,6 @@ struct ImageSegmenterOptions { OutputType output_type = OutputType::CATEGORY_MASK; - // The activation function used on the raw segmentation model output. - enum Activation { - NONE = 0, // No activation function is used. - SIGMOID = 1, // Assumes 1-channel input tensor. - SOFTMAX = 2, // Assumes multi-channel input tensor. - }; - - Activation activation = Activation::NONE; - // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 6a7e08626..fe6265b73 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" +#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter // subgraph. @@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator( const ImageSegmenterGraphOptions& segmenter_option, const core::ModelResources& model_resources, TensorsToSegmentationCalculatorOptions* options) { - *options->mutable_segmenter_options() = segmenter_option.segmenter_options(); + // Set default activation function NONE + options->mutable_segmenter_options()->set_output_type( + segmenter_option.segmenter_options().output_type()); + options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE); + // Find the custom metadata of ImageSegmenterOptions type in model metadata. + const auto* metadata_extractor = model_resources.GetMetadataExtractor(); + bool found_activation_in_metadata = false; + if (metadata_extractor->GetCustomMetadataList() != nullptr && + metadata_extractor->GetCustomMetadataList()->size() > 0) { + for (const auto& custom_metadata : + *metadata_extractor->GetCustomMetadataList()) { + if (custom_metadata->name()->str() == kSegmentationMetadataName) { + found_activation_in_metadata = true; + auto activation_fb = + GetImageSegmenterOptions(custom_metadata->data()->data()) + ->activation(); + switch (activation_fb) { + case Activation_NONE: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::NONE); + break; + case Activation_SIGMOID: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::SIGMOID); + break; + case Activation_SOFTMAX: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::SOFTMAX); + break; + default: + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid activation type found in CustomMetadata of " + "ImageSegmenterOptions type."); + } + } + } + } + if (!found_activation_in_metadata) { + LOG(WARNING) + << "No activation type is found in model metadata. Use NONE for " + "ImageSegmenterGraph."; + } const tflite::Model& model = *model_resources.GetTfLiteModel(); if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( @@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator( MediaPipeTasksStatus::kInvalidArgumentError); } - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); ASSIGN_OR_RETURN( *options->mutable_label_items(), GetLabelItemsIfAny(*metadata_extractor, diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index d063ca87a..1d75a3fb7 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -304,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -333,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -364,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -388,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -416,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); @@ -442,7 +437,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -470,7 +464,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -495,7 +488,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -521,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index 299423003..931740c8e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -641,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); } - // TODO: remove this once activation is handled in metadata and grpah level. - segmenterOptionsBuilder.setActivation( - SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX); taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( From 43082482f85321607987b083a356fa8962a6f8c8 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 15 Mar 2023 14:21:50 -0700 Subject: [PATCH 08/10] Remove framework:Cocoa again PiperOrigin-RevId: 516928735 --- third_party/BUILD | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/third_party/BUILD b/third_party/BUILD index ea037dce1..7522bab1b 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -169,11 +169,7 @@ cmake_external( "-lm", "-lpthread", "-lrt", - ] + select({ - "//mediapipe:ios": ["-framework Cocoa"], - "//mediapipe:macos": ["-framework Cocoa"], - "//conditions:default": [], - }), + ], shared_libraries = select({ "@bazel_tools//src/conditions:darwin": ["libopencv_%s.%s.dylib" % (module, OPENCV_SO_VERSION) for module in OPENCV_MODULES], # Only the shared objects listed here will be linked in the directory From 61bcddc6711a2b9d1a2a7676bb5fca1360360cd4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Mar 2023 16:02:48 -0700 Subject: [PATCH 09/10] Add Interactive Segmenter MediaPipe Task PiperOrigin-RevId: 516954589 --- .../cc/vision/interactive_segmenter/BUILD | 76 +++++ .../interactive_segmenter.cc | 163 +++++++++++ .../interactive_segmenter.h | 136 +++++++++ .../interactive_segmenter_graph.cc | 198 +++++++++++++ .../interactive_segmenter_test.cc | 261 ++++++++++++++++++ 5 files changed, 834 insertions(+) create mode 100644 mediapipe/tasks/cc/vision/interactive_segmenter/BUILD create mode 100644 mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc create mode 100644 mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h create mode 100644 mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc create mode 100644 mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD new file mode 100644 index 000000000..ea72d3d99 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD @@ -0,0 +1,76 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Docs for Mediapipe Tasks Interactive Segmenter +# TODO: add doc link. +cc_library( + name = "interactive_segmenter", + srcs = ["interactive_segmenter.cc"], + hdrs = ["interactive_segmenter.h"], + deps = [ + ":interactive_segmenter_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:keypoint", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "interactive_segmenter_graph", + srcs = ["interactive_segmenter_graph.cc"], + deps = [ + "@com_google_absl//absl/strings", + "//mediapipe/calculators/image:set_alpha_calculator", + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:render_data_cc_proto", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", + ], + }), + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc new file mode 100644 index 000000000..4298d4a19 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -0,0 +1,163 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { +namespace { + +constexpr char kSegmentationStreamName[] = "segmented_mask_out"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kRoiStreamName[] = "roi_in"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; + +constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kRoiTag[] = "ROI"; +constexpr char kNormRectTag[] = "NORM_RECT"; + +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; + +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: + image_segmenter::proto::ImageSegmenterGraphOptions; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& task_subgraph = graph.AddNode(kSubgraphTypeName); + task_subgraph.GetOptions().Swap( + options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kRoiTag).SetName(kRoiStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> + graph.Out(kGroupedSegmentationTag); + task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> + graph.Out(kImageTag); + graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kRoiTag) >> task_subgraph.In(kRoiTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing InteractiveSegmenterOptions struct to the internal +// ImageSegmenterOptions proto. +std::unique_ptr +ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + switch (options->output_type) { + case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK: + options_proto->mutable_segmenter_options()->set_output_type( + SegmenterOptions::CATEGORY_MASK); + break; + case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK: + options_proto->mutable_segmenter_options()->set_output_type( + SegmenterOptions::CONFIDENCE_MASK); + break; + } + return options_proto; +} + +// Converts the user-facing RegionOfInterest struct to the RenderData proto that +// is used in subgraph. +absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { + RenderData result; + switch (roi.format) { + case RegionOfInterest::UNSPECIFIED: + return absl::InvalidArgumentError( + "RegionOfInterest format not specified"); + case RegionOfInterest::KEYPOINT: + RET_CHECK(roi.keypoint.has_value()); + auto* annotation = result.add_render_annotations(); + annotation->mutable_color()->set_r(255); + auto* point = annotation->mutable_point(); + point->set_normalized(true); + point->set_x(roi.keypoint->x); + point->set_y(roi.keypoint->y); + return result; + } + return absl::UnimplementedError("Unrecognized format"); +} + +} // namespace + +absl::StatusOr> +InteractiveSegmenter::Create( + std::unique_ptr options) { + auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); + return core::VisionTaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver), core::RunningMode::IMAGE, + /*packets_callback=*/nullptr); +} + +absl::StatusOr> InteractiveSegmenter::Segment( + mediapipe::Image image, const RegionOfInterest& roi, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, + {kRoiStreamName, + mediapipe::MakePacket(std::move(roi_as_render_data))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + return output_packets[kSegmentationStreamName].Get>(); +} + +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h new file mode 100644 index 000000000..420b22462 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h @@ -0,0 +1,136 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { + +// The options for configuring a mediapipe interactive segmenter task. +struct InteractiveSegmenterOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // The output type of segmentation results. + enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK = 0, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK = 1, + }; + + OutputType output_type = OutputType::CATEGORY_MASK; +}; + +// The Region-Of-Interest (ROI) to interact with. +struct RegionOfInterest { + enum Format { + UNSPECIFIED = 0, // Format not specified. + KEYPOINT = 1, // Using keypoint to represent ROI. + }; + + // Specifies the format used to specify the region-of-interest. Note that + // using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status + // being returned. + Format format = Format::UNSPECIFIED; + + // Represents the ROI in keypoint format, this should be non-nullopt if + // `format` is `KEYPOINT`. + std::optional keypoint; +}; + +// Performs interactive segmentation on images. +// +// Users can represent user interaction through `RegionOfInterest`, which gives +// a hint to InteractiveSegmenter to perform segmentation focusing on the given +// region of interest. +// +// The API expects a TFLite model with mandatory TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - RGB inputs is supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// Output tensors: +// (kTfLiteUInt8/kTfLiteFloat32) +// - list of segmented masks. +// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. +// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size +// `channels`. +// - batch is always 1 +class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an InteractiveSegmenter from the provided options. A non-default + // OpResolver can be specified in the BaseOptions of + // InteractiveSegmenterOptions, to support custom Ops of the segmentation + // model. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs image segmentation on the provided single image. + // + // The image can be of any size with format RGB. + // + // The `roi` parameter is used to represent user's region of interest for + // segmentation. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // If the output_type is CATEGORY_MASK, the returned vector of images is + // per-category segmented image mask. + // If the output_type is CONFIDENCE_MASK, the returned vector of images + // contains only one confidence image mask. + absl::StatusOr> Segment( + mediapipe::Image image, const RegionOfInterest& roi, + std::optional image_processing_options = + std::nullopt); + + // Shuts down the InteractiveSegmenter when all works are done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_ diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc new file mode 100644 index 000000000..4c0cd2a88 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc @@ -0,0 +1,198 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/label_map.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { + +namespace { + +using image_segmenter::proto::ImageSegmenterGraphOptions; +using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; + +constexpr char kSegmentationTag[] = "SEGMENTATION"; +constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageCpuTag[] = "IMAGE_CPU"; +constexpr char kImageGpuTag[] = "IMAGE_GPU"; +constexpr char kAlphaTag[] = "ALPHA"; +constexpr char kAlphaGpuTag[] = "ALPHA_GPU"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kRoiTag[] = "ROI"; +constexpr char kVideoTag[] = "VIDEO"; + +// Updates the graph to return `roi` stream which has same dimension as +// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is +// in GpuBuffer format, otherwise using ImageFrame. +Source<> RoiToAlpha(Source image, Source roi, bool use_gpu, + Graph& graph) { + // TODO: Replace with efficient implementation. + const absl::string_view image_tag_with_suffix = + use_gpu ? kImageGpuTag : kImageCpuTag; + + // Generates a blank canvas with same size as input image. + auto& flat_color = graph.AddNode("FlatColorImageCalculator"); + auto& flat_color_options = + flat_color.GetOptions(); + // SetAlphaCalculator only takes 1st channel. + flat_color_options.mutable_color()->set_r(0); + image >> flat_color.In(kImageTag)[0]; + auto blank_canvas = flat_color.Out(kImageTag)[0]; + + auto& from_mp_image = graph.AddNode("FromImageCalculator"); + blank_canvas >> from_mp_image.In(kImageTag); + auto blank_canvas_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix); + + auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator"); + blank_canvas_in_cpu_or_gpu >> + roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag); + roi >> roi_to_alpha.In(0); + auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); + + return alpha; +} + +} // namespace + +// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph" +// performs semantic segmentation given user's region-of-interest. Two kinds of +// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can +// retrieve segmented mask of only particular category/channel from +// SEGMENTATION, and users can also get all segmented masks from +// GROUPED_SEGMENTATION. +// - Accepts CPU input images and outputs segmented masks on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform segmentation on. +// ROI - RenderData proto +// Region of interest based on user interaction. Currently only support +// Point format, and Color has to be (255, 255, 255). +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection +// on. +// @Optional: rect covering the whole image is used if not specified. +// +// Outputs: +// SEGMENTATION - mediapipe::Image @Multiple +// Segmented masks for individual category. Segmented mask of single +// category can be accessed by index based output stream. +// GROUPED_SEGMENTATION - std::vector +// The output segmented masks grouped in a vector. +// IMAGE - mediapipe::Image +// The image that image segmenter runs on. +// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph" +// input_stream: "IMAGE:image" +// input_stream: "ROI:region_of_interest" +// output_stream: "SEGMENTATION:segmented_masks" +// options { +// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// segmenter_options { +// output_type: CONFIDENCE_MASK +// } +// } +// } +// } +class InteractiveSegmenterGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + Graph graph; + const auto& task_options = sc->Options(); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + + Source image = graph[Input(kImageTag)]; + Source roi = graph[Input(kRoiTag)]; + Source norm_rect = + graph[Input(kNormRectTag)]; + const absl::string_view image_tag_with_suffix = + use_gpu ? kImageGpuTag : kImageCpuTag; + const absl::string_view alpha_tag_with_suffix = + use_gpu ? kAlphaGpuTag : kAlphaTag; + + auto& from_mp_image = graph.AddNode("FromImageCalculator"); + image >> from_mp_image.In(kImageTag); + auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix); + + auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph); + + auto& set_alpha = graph.AddNode("SetAlphaCalculator"); + image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag); + alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix); + auto image_in_cpu_or_gpu_with_set_alpha = + set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); + + auto& to_mp_image = graph.AddNode("ToImageCalculator"); + image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix); + auto image_with_set_alpha = to_mp_image.Out(kImageTag); + + auto& image_segmenter = graph.AddNode( + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"); + image_segmenter.GetOptions() = task_options; + image_with_set_alpha >> image_segmenter.In(kImageTag); + norm_rect >> image_segmenter.In(kNormRectTag); + + image_segmenter.Out(kSegmentationTag) >> + graph[Output(kSegmentationTag)]; + image_segmenter.Out(kGroupedSegmentationTag) >> + graph[Output>(kGroupedSegmentationTag)]; + image_segmenter.Out(kImageTag) >> graph[Output(kImageTag)]; + + return graph.GetConfig(); + } +}; + +// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly. +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph); +// clang-format on + +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc new file mode 100644 index 000000000..dbe021dce --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -0,0 +1,261 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { +namespace { + +using ::mediapipe::Image; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::RectF; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite"; +constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg"; + +constexpr float kGoldenMaskSimilarity = 0.98; + +// Magnification factor used when creating the golden category masks to make +// them more human-friendly. Each pixel in the golden masks has its value +// multiplied by this factor, i.e. a value of 10 means class index 1, a value of +// 20 means class index 2, etc. +constexpr int kGoldenMaskMagnificationFactor = 10; + +// Intentionally converting output into CV_8UC1 and then again into CV_32FC1 +// as expected outputs are stored in CV_8UC1, so this conversion allows to do +// fair comparison. +cv::Mat PostProcessResultMask(const cv::Mat& mask) { + cv::Mat mask_float; + mask.convertTo(mask_float, CV_8UC1, 255); + mask_float.convertTo(mask_float, CV_32FC1, 1 / 255.f); + return mask_float; +} + +double CalculateSum(const cv::Mat& m) { + double sum = 0.0; + cv::Scalar s = cv::sum(m); + for (int i = 0; i < m.channels(); ++i) { + sum += s.val[i]; + } + return sum; +} + +double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) { + cv::Mat intersection; + cv::multiply(m1, m2, intersection); + double intersection_value = CalculateSum(intersection); + double union_value = + CalculateSum(m1.mul(m1)) + CalculateSum(m2.mul(m2)) - intersection_value; + return union_value > 0.0 ? intersection_value / union_value : 0.0; +} + +MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") { + cv::Mat actual_mask = PostProcessResultMask(arg); + return arg.rows == expected_mask.rows && arg.cols == expected_mask.cols && + CalculateSoftIOU(arg, expected_mask) > similarity_threshold; +} + +MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold, + magnification_factor, "") { + if (arg.rows != expected_mask.rows || arg.cols != expected_mask.cols) { + return false; + } + int consistent_pixels = 0; + const int num_pixels = expected_mask.rows * expected_mask.cols; + for (int i = 0; i < num_pixels; ++i) { + consistent_pixels += + (arg.data[i] * magnification_factor == expected_mask.data[i]); + } + return static_cast(consistent_pixels) / num_pixels >= + similarity_threshold; +} + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { + public: + DeepLabOpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + } + + DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete; +}; + +TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->base_options.op_resolver = + absl::make_unique(); + auto segmenter_or = InteractiveSegmenter::Create(std::move(options)); + // TODO: Make MediaPipe InferenceCalculator report the detailed + // interpreter errors (e.g., "Encountered unresolved custom op"). + EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal); + EXPECT_THAT( + segmenter_or.status().message(), + testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + absl::StatusOr> segmenter_or = + InteractiveSegmenter::Create( + std::make_unique()); + + EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + segmenter_or.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +class ImageModeTest : public tflite_shims::testing::Test {}; + +TEST_F(ImageModeTest, SucceedsWithCategoryMask) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.keypoint = + components::containers::NormalizedKeypoint{0.25, 0.9}; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto category_masks, + segmenter->Segment(image, interaction_roi)); + EXPECT_EQ(category_masks.size(), 1); +} + +TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.keypoint = + components::containers::NormalizedKeypoint{0.25, 0.9}; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = + InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + segmenter->Segment(image, interaction_roi)); + EXPECT_EQ(confidence_masks.size(), 2); +} + +// TODO: fix this unit test after image segmenter handled post +// processing correctly with rotated image. +TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.keypoint = + components::containers::NormalizedKeypoint{0.25, 0.9}; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = + InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + MP_ASSERT_OK_AND_ASSIGN( + auto confidence_masks, + segmenter->Segment(image, interaction_roi, image_processing_options)); + EXPECT_EQ(confidence_masks.size(), 2); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.keypoint = + components::containers::NormalizedKeypoint{0.25, 0.9}; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = + InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = + segmenter->Segment(image, interaction_roi, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + +} // namespace +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe From 8f1ce5fef6b37d5ebf34ffd5dd9dd8c6365b2b75 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Mar 2023 17:00:50 -0700 Subject: [PATCH 10/10] Add quality test for InteractiveSegmenter PiperOrigin-RevId: 516968294 --- .../interactive_segmenter_test.cc | 83 ++++++++++++++----- mediapipe/tasks/testdata/vision/BUILD | 4 + third_party/external_files.bzl | 20 +++-- 3 files changed, 81 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc index dbe021dce..dbc3bbe4c 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" -#include #include +#include #include "absl/flags/flag.h" #include "mediapipe/framework/deps/file_path.h" @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" @@ -47,6 +48,7 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::NormalizedKeypoint; using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; @@ -55,14 +57,16 @@ using ::testing::Optional; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite"; constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg"; +// Golden mask for the dogs in cats_and_dogs.jpg. +constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png"; +constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png"; -constexpr float kGoldenMaskSimilarity = 0.98; +constexpr float kGoldenMaskSimilarity = 0.97; // Magnification factor used when creating the golden category masks to make -// them more human-friendly. Each pixel in the golden masks has its value -// multiplied by this factor, i.e. a value of 10 means class index 1, a value of -// 20 means class index 2, etc. -constexpr int kGoldenMaskMagnificationFactor = 10; +// them more human-friendly. Since interactive segmenter has only 2 categories, +// the golden mask uses 0 or 255 for each pixel. +constexpr int kGoldenMaskMagnificationFactor = 255; // Intentionally converting output into CV_8UC1 and then again into CV_32FC1 // as expected outputs are stored in CV_8UC1, so this conversion allows to do @@ -155,16 +159,25 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { MediaPipeTasksStatus::kRunnerInitializationError)))); } -class ImageModeTest : public tflite_shims::testing::Test {}; +struct InteractiveSegmenterTestParams { + std::string test_name; + RegionOfInterest::Format format; + NormalizedKeypoint roi; + std::string golden_mask_file; + float similarity_threshold; +}; -TEST_F(ImageModeTest, SucceedsWithCategoryMask) { +using SucceedSegmentationWithRoi = + ::testing::TestWithParam; + +TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { + const InteractiveSegmenterTestParams& params = GetParam(); MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); RegionOfInterest interaction_roi; - interaction_roi.format = RegionOfInterest::KEYPOINT; - interaction_roi.keypoint = - components::containers::NormalizedKeypoint{0.25, 0.9}; + interaction_roi.format = params.format; + interaction_roi.keypoint = params.roi; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -175,16 +188,26 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image, interaction_roi)); EXPECT_EQ(category_masks.size(), 1); + + cv::Mat actual_mask = mediapipe::formats::MatView( + category_masks[0].GetImageFrameSharedPtr().get()); + + cv::Mat expected_mask = + cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), + cv::IMREAD_GRAYSCALE); + EXPECT_THAT(actual_mask, + SimilarToUint8Mask(expected_mask, params.similarity_threshold, + kGoldenMaskMagnificationFactor)); } -TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { +TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { + const auto& params = GetParam(); MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); RegionOfInterest interaction_roi; - interaction_roi.format = RegionOfInterest::KEYPOINT; - interaction_roi.keypoint = - components::containers::NormalizedKeypoint{0.25, 0.9}; + interaction_roi.format = params.format; + interaction_roi.keypoint = params.roi; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -196,8 +219,32 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image, interaction_roi)); EXPECT_EQ(confidence_masks.size(), 2); + + cv::Mat expected_mask = + cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), + cv::IMREAD_GRAYSCALE); + cv::Mat expected_mask_float; + expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); + + cv::Mat actual_mask = mediapipe::formats::MatView( + confidence_masks[1].GetImageFrameSharedPtr().get()); + EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float, + params.similarity_threshold)); } +INSTANTIATE_TEST_SUITE_P( + SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, + ::testing::ValuesIn( + {{"PointToDog1", RegionOfInterest::KEYPOINT, + NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, + {"PointToDog2", RegionOfInterest::KEYPOINT, + NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, + kGoldenMaskSimilarity}}), + [](const ::testing::TestParamInfo& + info) { return info.param.test_name; }); + +class ImageModeTest : public tflite_shims::testing::Test {}; + // TODO: fix this unit test after image segmenter handled post // processing correctly with rotated image. TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { @@ -206,8 +253,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); RegionOfInterest interaction_roi; interaction_roi.format = RegionOfInterest::KEYPOINT; - interaction_roi.keypoint = - components::containers::NormalizedKeypoint{0.25, 0.9}; + interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -230,8 +276,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); RegionOfInterest interaction_roi; interaction_roi.format = RegionOfInterest::KEYPOINT; - interaction_roi.keypoint = - components::containers::NormalizedKeypoint{0.25, 0.9}; + interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ac76bfa23..097acad43 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -31,6 +31,8 @@ mediapipe_files(srcs = [ "cat_rotated.jpg", "cat_rotated_mask.jpg", "cats_and_dogs.jpg", + "cats_and_dogs_mask_dog1.png", + "cats_and_dogs_mask_dog2.png", "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", @@ -116,6 +118,8 @@ filegroup( "cat_rotated.jpg", "cat_rotated_mask.jpg", "cats_and_dogs.jpg", + "cats_and_dogs_mask_dog1.png", + "cats_and_dogs_mask_dog2.png", "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", "fist.jpg", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 52636f427..3a08c61c5 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -67,13 +67,7 @@ def external_files(): http_file( name = "com_google_mediapipe_BUILD", sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=16618756636939761678323576393653"], - ) - - http_file( - name = "com_google_mediapipe_BUILD_orig", - sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"], + urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"], ) http_file( @@ -136,6 +130,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs.jpg?generation=1661875684064150"], ) + http_file( + name = "com_google_mediapipe_cats_and_dogs_mask_dog1_png", + sha256 = "2ab37d56ba1e46e70b3ddbfe35dac51b18b597b76904c68d7d34c7c74c677d4c", + urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog1.png?generation=1678840350058498"], + ) + + http_file( + name = "com_google_mediapipe_cats_and_dogs_mask_dog2_png", + sha256 = "2010850e2dd7f520fe53b9086d70913b6fb53b178cae15a373e5ee7ffb46824a", + urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog2.png?generation=1678840352961684"], + ) + http_file( name = "com_google_mediapipe_cats_and_dogs_no_resizing_jpg", sha256 = "9d55933ed66bcdc63cd6509ee2518d7eed75d12db609238387ee4cc50b173e58",