From 9a1a9d4c136685962afc0a6bc81e4b458d57d2e0 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 24 Oct 2022 06:08:27 -0700 Subject: [PATCH 01/34] Added files needed for the GestureRecognizer API implementation --- mediapipe/python/BUILD | 1 + .../tasks/python/components/containers/BUILD | 37 ++ .../components/containers/classification.py | 128 ++++++ .../python/components/containers/gesture.py | 138 ++++++ .../python/components/containers/landmark.py | 250 ++++++++++ .../python/components/containers/rect.py | 141 ++++++ .../tasks/python/components/processors/BUILD | 28 ++ .../processors/classifier_options.py | 92 ++++ mediapipe/tasks/python/test/vision/BUILD | 19 + .../test/vision/gesture_recognizer_test.py | 91 ++++ mediapipe/tasks/python/vision/BUILD | 27 ++ .../tasks/python/vision/gesture_recognizer.py | 434 ++++++++++++++++++ 12 files changed, 1386 insertions(+) create mode 100644 mediapipe/tasks/python/components/containers/classification.py create mode 100644 mediapipe/tasks/python/components/containers/gesture.py create mode 100644 mediapipe/tasks/python/components/containers/landmark.py create mode 100644 mediapipe/tasks/python/components/containers/rect.py create mode 100644 mediapipe/tasks/python/components/processors/BUILD create mode 100644 mediapipe/tasks/python/components/processors/classifier_options.py create mode 100644 mediapipe/tasks/python/test/vision/gesture_recognizer_test.py create mode 100644 mediapipe/tasks/python/vision/gesture_recognizer.py diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 2911e2fd6..50a1f5791 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -88,6 +88,7 @@ cc_library( name = "builtin_task_graphs", deps = [ "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", ], ) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index fd25401f7..325dff5fc 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -27,6 +27,43 @@ py_library( ], ) +py_library( + name = "rect", + srcs = ["rect.py"], + deps = [ + "//mediapipe/framework/formats:rect_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + +py_library( + name = "classification", + srcs = ["classification.py"], + deps = [ + "//mediapipe/framework/formats:classification_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + +py_library( + name = "landmark", + srcs = ["landmark.py"], + deps = [ + "//mediapipe/framework/formats:landmark_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + +py_library( + name = "gesture", + srcs = ["gesture.py"], + deps = [ + ":classification", + ":landmark", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + py_library( name = "category", srcs = ["category.py"], diff --git a/mediapipe/tasks/python/components/containers/classification.py b/mediapipe/tasks/python/components/containers/classification.py new file mode 100644 index 000000000..157c34528 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/classification.py @@ -0,0 +1,128 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classification data class.""" + +import dataclasses +from typing import Any, List + +from mediapipe.framework.formats import classification_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassificationProto = classification_pb2.Classification +_ClassificationListProto = classification_pb2.ClassificationList +_ClassificationListCollectionProto = classification_pb2.ClassificationListCollection + + +@dataclasses.dataclass +class Classification: + """A classification. + + Attributes: + index: The index of the class in the corresponding label map. + score: The probability score for this class. + label_name: Label or name of the class. + display_name: Optional human-readable string for display purposes. + """ + + index: int + score: float + label_name: str + display_name: str + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationProto: + """Generates a Classification protobuf object.""" + return _ClassificationProto( + index=self.index, + score=self.score, + label_name=self.label_name, + display_name=self.display_name) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Classification': + """Creates a `Classification` object from the given protobuf object.""" + return Classification( + index=pb2_obj.index, + score=pb2_obj.score, + label_name=pb2_obj.label_name, + display_name=pb2_obj.display_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Classification): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class ClassificationList: + """Represents the classifications for a given classifier. + Attributes: + classification : A list of `Classification` objects. + tensor_index: Optional index of the tensor that produced these + classifications. + tensor_name: Optional name of the tensor that produced these + classifications tensor metadata name. + """ + + classifications: List[Classification] + tensor_index: int + tensor_name: str + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationListProto: + """Generates a ClassificationList protobuf object.""" + return _ClassificationListProto( + classification=[ + classification.to_pb2() + for classification in self.classifications + ], + tensor_index=self.tensor_index, + tensor_name=self.tensor_name) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _ClassificationListProto + ) -> 'ClassificationList': + """Creates a `ClassificationList` object from the given protobuf object.""" + return ClassificationList( + classifications=[ + Classification.create_from_pb2(classification) + for classification in pb2_obj.classification + ], + tensor_index=pb2_obj.tensor_index, + tensor_name=pb2_obj.tensor_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. 
+ Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationList): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/gesture.py b/mediapipe/tasks/python/components/containers/gesture.py new file mode 100644 index 000000000..f314d18bd --- /dev/null +++ b/mediapipe/tasks/python/components/containers/gesture.py @@ -0,0 +1,138 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gesture data class.""" + +import dataclasses +from typing import Any, List + +from mediapipe.tasks.python.components.containers import classification +from mediapipe.tasks.python.components.containers import landmark +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + + +@dataclasses.dataclass +class GestureRecognitionResult: + """ The gesture recognition result from GestureRecognizer, where each vector + element represents a single hand detected in the image. + + Attributes: + gestures: Recognized hand gestures with sorted order such that the + winning label is the first item in the list. + handedness: Classification of handedness. + hand_landmarks: Detected hand landmarks in normalized image coordinates. + hand_world_landmarks: Detected hand landmarks in world coordinates. 
+ """ + + gestures: List[classification.ClassificationList] + handedness: List[classification.ClassificationList] + hand_landmarks: List[landmark.NormalizedLandmarkList] + hand_world_landmarks: List[landmark.LandmarkList] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _DetectionProto: + """Generates a Detection protobuf object.""" + labels = [] + label_ids = [] + scores = [] + display_names = [] + for category in self.categories: + scores.append(category.score) + if category.index: + label_ids.append(category.index) + if category.category_name: + labels.append(category.category_name) + if category.display_name: + display_names.append(category.display_name) + return _DetectionProto( + label=labels, + label_id=label_ids, + score=scores, + display_name=display_names, + location_data=_LocationDataProto( + format=_LocationDataProto.Format.BOUNDING_BOX, + bounding_box=self.bounding_box.to_pb2())) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection': + """Creates a `Detection` object from the given protobuf object.""" + categories = [] + for idx, score in enumerate(pb2_obj.score): + categories.append( + category_module.Category( + score=score, + index=pb2_obj.label_id[idx] + if idx < len(pb2_obj.label_id) else None, + category_name=pb2_obj.label[idx] + if idx < len(pb2_obj.label) else None, + display_name=pb2_obj.display_name[idx] + if idx < len(pb2_obj.display_name) else None)) + + return Detection( + bounding_box=bounding_box_module.BoundingBox.create_from_pb2( + pb2_obj.location_data.bounding_box), + categories=categories) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Detection): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class DetectionResult: + """Represents the list of detected objects. + + Attributes: + detections: A list of `Detection` objects. + """ + + detections: List[Detection] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _DetectionListProto: + """Generates a DetectionList protobuf object.""" + return _DetectionListProto( + detection=[detection.to_pb2() for detection in self.detections]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult': + """Creates a `DetectionResult` object from the given protobuf object.""" + return DetectionResult(detections=[ + Detection.create_from_pb2(detection) for detection in pb2_obj.detection + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, DetectionResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py new file mode 100644 index 000000000..a86c17f2f --- /dev/null +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -0,0 +1,250 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Landmark data class.""" + +import dataclasses +from typing import Any, Optional, List + +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_LandmarkProto = landmark_pb2.Landmark +_LandmarkListProto = landmark_pb2.LandmarkList +_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark +_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList + + +@dataclasses.dataclass +class Landmark: + """A landmark that can have 1 to 3 dimensions. + + Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points. + + Attributes: + x: The x coordinate of the 3D point. + y: The y coordinate of the 3D point. + z: The z coordinate of the 3D point. + visibility: Landmark visibility. Should stay unset if not supported. + Float score of whether landmark is visible or occluded by other objects. + Landmark considered as invisible also if it is not present on the screen + (out of scene bounds). Depending on the model, visibility value is either + a sigmoid or an argument of sigmoid. + presence: Landmark presence. Should stay unset if not supported. + Float score of whether landmark is present on the scene (located within + scene bounds). Depending on the model, presence value is either a result + of sigmoid or an argument of sigmoid function to get landmark presence + probability. + """ + + x: Optional[float] = None + y: Optional[float] = None + z: Optional[float] = None + visibility: Optional[float] = None + presence: Optional[float] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _LandmarkProto: + """Generates a Landmark protobuf object.""" + return _LandmarkProto( + x=self.x, + y=self.y, + z=self.z, + visibility=self.visibility, + presence=self.presence) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _LandmarkProto) -> 'Landmark': + """Creates a `Landmark` object from the given protobuf object.""" + return Landmark( + x=pb2_obj.x, + y=pb2_obj.y, + z=pb2_obj.z, + visibility=pb2_obj.visibility, + presence=pb2_obj.presence) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Landmark): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class LandmarkList: + """Represents the list of landmarks. + + Attributes: + landmarks : A list of `Landmark` objects. 
+ """ + + landmarks: List[Landmark] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _LandmarkListProto: + """Generates a LandmarkList protobuf object.""" + return _LandmarkListProto( + landmark=[ + landmark.to_pb2() + for landmark in self.landmarks + ] + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _LandmarkListProto + ) -> 'LandmarkList': + """Creates a `LandmarkList` object from the given protobuf object.""" + return LandmarkList( + landmarks=[ + Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.landmark + ] + ) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, LandmarkList): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class NormalizedLandmark: + """A normalized version of above Landmark proto. + + All coordinates should be within [0, 1]. + + Attributes: + x: The normalized x coordinate of the 3D point. + y: The normalized y coordinate of the 3D point. + z: The normalized z coordinate of the 3D point. + visibility: Landmark visibility. Should stay unset if not supported. + Float score of whether landmark is visible or occluded by other objects. + Landmark considered as invisible also if it is not present on the screen + (out of scene bounds). Depending on the model, visibility value is either + a sigmoid or an argument of sigmoid. + presence: Landmark presence. Should stay unset if not supported. + Float score of whether landmark is present on the scene (located within + scene bounds). Depending on the model, presence value is either a result + of sigmoid or an argument of sigmoid function to get landmark presence + probability. + """ + + x: Optional[float] = None + y: Optional[float] = None + z: Optional[float] = None + visibility: Optional[float] = None + presence: Optional[float] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _NormalizedLandmarkProto: + """Generates a NormalizedLandmark protobuf object.""" + return _NormalizedLandmarkProto( + x=self.x, + y=self.y, + z=self.z, + visibility=self.visibility, + presence=self.presence) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _NormalizedLandmarkProto + ) -> 'NormalizedLandmark': + """Creates a `NormalizedLandmark` object from the given protobuf object.""" + return NormalizedLandmark( + x=pb2_obj.x, + y=pb2_obj.y, + z=pb2_obj.z, + visibility=pb2_obj.visibility, + presence=pb2_obj.presence) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedLandmark): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class NormalizedLandmarkList: + """Represents the list of normalized landmarks. + + Attributes: + landmarks : A list of `Landmark` objects. 
+ """ + + landmarks: List[NormalizedLandmark] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _NormalizedLandmarkListProto: + """Generates a NormalizedLandmarkList protobuf object.""" + return _NormalizedLandmarkListProto( + landmark=[ + landmark.to_pb2() + for landmark in self.landmarks + ] + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _NormalizedLandmarkListProto + ) -> 'NormalizedLandmarkList': + """Creates a `NormalizedLandmarkList` object from the given protobuf object.""" + return NormalizedLandmarkList( + landmarks=[ + NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.landmark + ] + ) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedLandmarkList): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py new file mode 100644 index 000000000..aadb404db --- /dev/null +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -0,0 +1,141 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rect data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.framework.formats import rect_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_RectProto = rect_pb2.Rect +_NormalizedRectProto = rect_pb2.NormalizedRect + + +@dataclasses.dataclass +class Rect: + """A rectangle with rotation in image coordinates. + + Attributes: + x_center : The X coordinate of the top-left corner, in pixels. + y_center : The Y coordinate of the top-left corner, in pixels. + width: The width of the rectangle, in pixels. + height: The height of the rectangle, in pixels. + rotation: Rotation angle is clockwise in radians. + rect_id: Optional unique id to help associate different rectangles to each + other. + """ + + x_center: int + y_center: int + width: int + height: int + rotation: Optional[float] = 0.0 + rect_id: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _RectProto: + """Generates a Rect protobuf object.""" + return _RectProto( + x_center=self.x_center, + y_center=self.y_center, + width=self.width, + height=self.height, + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect': + """Creates a `Rect` object from the given protobuf object.""" + return Rect( + x_center=pb2_obj.x_center, + y_center=pb2_obj.y_center, + width=pb2_obj.width, + height=pb2_obj.height) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. 
+ """ + if not isinstance(other, Rect): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class NormalizedRect: + """A rectangle with rotation in normalized coordinates. The values of box + center location and size are within [0, 1]. + + Attributes: + x_center : The X normalized coordinate of the top-left corner. + y_center : The Y normalized coordinate of the top-left corner. + width: The width of the rectangle. + height: The height of the rectangle. + rotation: Rotation angle is clockwise in radians. + rect_id: Optional unique id to help associate different rectangles to each + other. + """ + + x_center: float + y_center: float + width: float + height: float + rotation: Optional[float] = 0.0 + rect_id: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _NormalizedRectProto: + """Generates a NormalizedRect protobuf object.""" + return _NormalizedRectProto( + x_center=self.x_center, + y_center=self.y_center, + width=self.width, + height=self.height, + rotation=self.rotation, + rect_id=self.rect_id + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _NormalizedRectProto) -> 'NormalizedRect': + """Creates a `NormalizedRect` object from the given protobuf object.""" + return NormalizedRect( + x_center=pb2_obj.x_center, + y_center=pb2_obj.y_center, + width=pb2_obj.width, + height=pb2_obj.height, + rotation=pb2_obj.rotation, + rect_id=pb2_obj.rect_id + ) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedRect): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD new file mode 100644 index 000000000..814e15d1f --- /dev/null +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -0,0 +1,28 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "classifier_options", + srcs = ["classifier_options.py"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/processors/classifier_options.py b/mediapipe/tasks/python/components/processors/classifier_options.py new file mode 100644 index 000000000..b4597e57a --- /dev/null +++ b/mediapipe/tasks/python/components/processors/classifier_options.py @@ -0,0 +1,92 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classifier options data class.""" + +import dataclasses +from typing import Any, List, Optional + +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions + + +@dataclasses.dataclass +class ClassifierOptions: + """Options for classification processor. + + Attributes: + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, detection + results whose category name is not in this set will be filtered out. + Duplicate or unknown category names are ignored. Mutually exclusive with + `category_denylist`. + category_denylist: Denylist of category names. If non-empty, detection + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. + """ + + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassifierOptionsProto: + """Generates a ClassifierOptions protobuf object.""" + return _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _ClassifierOptionsProto + ) -> 'ClassifierOptions': + """Creates a `ClassifierOptions` object from the given protobuf object.""" + return ClassifierOptions( + score_threshold=pb2_obj.score_threshold, + category_allowlist=[ + str(name) for name in pb2_obj.class_name_allowlist + ], + category_denylist=[ + str(name) for name in pb2_obj.class_name_denylist + ], + display_names_locale=pb2_obj.display_names_locale, + max_results=pb2_obj.max_results) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. 
+ """ + if not isinstance(other, ClassifierOptions): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 290b665e7..0dd83edcf 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -36,3 +36,22 @@ py_test( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_test( + name = "gesture_recognizer_test", + srcs = ["gesture_recognizer_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:classification", + "//mediapipe/tasks/python/components/containers:landmark", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:gesture_recognizer", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py new file mode 100644 index 000000000..288cfd1f5 --- /dev/null +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -0,0 +1,91 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for gesture recognizer.""" + +import enum + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import rect as rect_module +from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import gesture_recognizer +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_NormalizedRect = rect_module.NormalizedRect +_ClassificationList = classification_module.ClassificationList +_LandmarkList = landmark_module.LandmarkList +_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList +_Image = image_module.Image +_GestureRecognizer = gesture_recognizer.GestureRecognizer +_GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions +_GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode + +_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task' +_IMAGE_FILE = 'right_hands.jpg' +_EXPECTED_DETECTION_RESULT = _GestureRecognitionResult([], [], [], []) + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class GestureRecognizerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path(_IMAGE_FILE)) + self.gesture_recognizer_model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_MODEL_FILE) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _EXPECTED_DETECTION_RESULT), + (ModelFileType.FILE_CONTENT, _EXPECTED_DETECTION_RESULT)) + def test_recognize(self, model_file_type, expected_recognition_result): + # Creates gesture recognizer. + if model_file_type is ModelFileType.FILE_NAME: + gesture_recognizer_base_options = _BaseOptions( + model_asset_path=self.gesture_recognizer_model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.gesture_recognizer_model_path, 'rb') as f: + model_content = f.read() + gesture_recognizer_base_options = _BaseOptions( + model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _GestureRecognizerOptions( + base_options=gesture_recognizer_base_options) + recognizer = _GestureRecognizer.create_from_options(options) + + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(self.test_image) + # Comparing results. + self.assertEqual(recognition_result, expected_recognition_result) + # Closes the gesture recognizer explicitly when the detector is not used in + # a context. 
+ recognizer.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e7be51c8d..9a9ca3429 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -36,3 +36,30 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "gesture_recognizer", + srcs = [ + "gesture_recognizer.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/components/containers:classification", + "//mediapipe/tasks/python/components/containers:landmark", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py new file mode 100644 index 000000000..aca7a5277 --- /dev/null +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -0,0 +1,434 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe gesture recognizer task.""" + +import dataclasses +from typing import Callable, Mapping, Optional, List + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.python._framework_bindings import task_runner as task_runner_module +from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_classifier_graph_options_pb2 +from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2 +from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2 +from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2 +from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 +from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2 +from mediapipe.tasks.python.components.containers import rect as rect_module +from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_NormalizedRect = rect_module.NormalizedRect +_BaseOptions = base_options_module.BaseOptions +_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions +_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions +_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions +_HandDetectorGraphOptionsProto = hand_detector_graph_options_pb2.HandDetectorGraphOptions +_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions +_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_RunningMode = running_mode_module.VisionTaskRunningMode +_TaskInfo = task_info_module.TaskInfo +_TaskRunner = task_runner_module.TaskRunner + +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_HAND_GESTURE_STREAM_NAME = 'hand_gestures' +_HAND_GESTURE_TAG = 'HAND_GESTURES' +_HANDEDNESS_STREAM_NAME = 'handedness' +_HANDEDNESS_TAG = 'HANDEDNESS' +_HAND_LANDMARKS_STREAM_NAME = 'landmarks' +_HAND_LANDMARKS_TAG = 'LANDMARKS' +_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks' +_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +def _build_full_image_norm_rect() -> _NormalizedRect: + # Builds a NormalizedRect covering the entire image. 
+ return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) + + +@dataclasses.dataclass +class GestureRecognitionResult: + """The gesture recognition result from GestureRecognizer, where each vector + element represents a single hand detected in the image. + + Attributes: + gestures: Recognized hand gestures with sorted order such that the + winning label is the first item in the list. + handedness: Classification of handedness. + hand_landmarks: Detected hand landmarks in normalized image coordinates. + hand_world_landmarks: Detected hand landmarks in world coordinates. + """ + + gestures: List[classification_module.ClassificationList] + handedness: List[classification_module.ClassificationList] + hand_landmarks: List[landmark_module.NormalizedLandmarkList] + hand_world_landmarks: List[landmark_module.LandmarkList] + + +@dataclasses.dataclass +class GestureRecognizerOptions: + """Options for the gesture recognizer task. + + Attributes: + base_options: Base options for the hand gesture recognizer task. + running_mode: The running mode of the task. Default to the image mode. + Gesture recognizer task has three running modes: + 1) The image mode for recognizing hand gestures on single image inputs. + 2) The video mode for recognizing hand gestures on the decoded frames of a + video. + 3) The live stream mode for recognizing hand gestures on a live stream of + input data, such as from camera. + num_hands: The maximum number of hands can be detected by the recognizer. + min_hand_detection_confidence: The minimum confidence score for the hand + detection to be considered successful. + min_hand_presence_confidence: The minimum confidence score of hand presence + score in the hand landmark detection. + min_tracking_confidence: The minimum confidence score for the hand tracking + to be considered successful. + min_gesture_confidence: The minimum confidence score for the gestures to be + considered successful. If < 0, the gesture confidence thresholds in the + model metadata are used. + TODO: Note this option is subject to change, after scoring merging + calculator is implemented. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + num_hands: Optional[int] = 1 + min_hand_detection_confidence: Optional[int] = 0.5 + min_hand_presence_confidence: Optional[int] = 0.5 + min_tracking_confidence: Optional[int] = 0.5 + min_gesture_confidence: Optional[int] = -1 + result_callback: Optional[ + Callable[[GestureRecognitionResult, image_module.Image, + int], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _GestureRecognizerGraphOptionsProto: + """Generates an GestureRecognizerOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + # hand_landmark_detector_base_options_proto = self.hand_landmark_detector_base_options.to_pb2() + # hand_landmark_detector_base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + + # Configure hand detector options. + hand_detector_options_proto = _HandDetectorGraphOptionsProto( + num_hands=self.num_hands, + min_detection_confidence=self.min_hand_detection_confidence) + + # Configure hand landmarker options. 
+ hand_landmarks_detector_options_proto = _HandLandmarksDetectorGraphOptionsProto( + min_detection_confidence=self.min_hand_presence_confidence) + hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto( + hand_detector_graph_options=hand_detector_options_proto, + hand_landmarks_detector_graph_options=hand_landmarks_detector_options_proto, + min_tracking_confidence=self.min_tracking_confidence) + + # Configure hand gesture recognizer options. + hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() + if self.min_gesture_confidence >= 0: + classifier_options = _ClassifierOptions( + score_threshold=self.min_gesture_confidence) + hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options = \ + _GestureClassifierGraphOptionsProto( + classifier_options=classifier_options.to_pb2()) + + return _GestureRecognizerGraphOptionsProto( + base_options=base_options_proto, + hand_landmarker_graph_options=hand_landmarker_options_proto, + hand_gesture_recognizer_graph_options=hand_gesture_recognizer_options_proto + ) + + +class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): + """Class that performs gesture recognition on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'GestureRecognizer': + """Creates an `GestureRecognizer` object from a TensorFlow Lite model and + the default `GestureRecognizerOptions`. + + Note that the created `GestureRecognizer` instance is in image mode, for + recognizing hand gestures on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `GestureRecognizer` object that's created from the model file and the + default `GestureRecognizerOptions`. + + Raises: + ValueError: If failed to create `GestureRecognizer` object from the + provided file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = GestureRecognizerOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE) + return cls.create_from_options(options) + + @classmethod + def create_from_options( + cls, + options: GestureRecognizerOptions + ) -> 'GestureRecognizer': + """Creates the `GestureRecognizer` object from gesture recognizer options. + + Args: + options: Options for the gesture recognizer task. + + Returns: + `GestureRecognizer` object that's created from `options`. + + Raises: + ValueError: If failed to create `GestureRecognizer` object from + `GestureRecognizerOptions` such as missing the model. + RuntimeError: If other types of error occurred. 
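+
+    Example (an illustrative sketch; the model path is a placeholder and the
+    option values are arbitrary):
+
+      options = GestureRecognizerOptions(
+          base_options=_BaseOptions(model_asset_path='/path/to/model.task'),
+          running_mode=_RunningMode.IMAGE,
+          num_hands=2)
+      recognizer = GestureRecognizer.create_from_options(options)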
+ """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + + if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): + empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME] + options.result_callback( + GestureRecognitionResult([], [], [], []), image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + return + + gestures_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_GESTURE_STREAM_NAME]) + handedness_proto_list = packet_getter.get_proto_list( + output_packets[_HANDEDNESS_STREAM_NAME]) + hand_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_LANDMARKS_STREAM_NAME]) + hand_world_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + + gesture_recognition_result = GestureRecognitionResult( + [ + classification_module.ClassificationList.create_from_pb2(gestures) + for gestures in gestures_proto_list + ], [ + classification_module.ClassificationList.create_from_pb2(handedness) + for handedness in handedness_proto_list + ], [ + landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks) + for hand_landmarks in hand_landmarks_proto_list + ], [ + landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks) + for hand_world_landmarks in hand_world_landmarks_proto_list + ] + ) + timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp + options.result_callback( + gesture_recognition_result, image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=[ + ':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]), + ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]), + ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]), + ':'.join([_HAND_WORLD_LANDMARKS_TAG, + _HAND_WORLD_LANDMARKS_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ], + task_options=options) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode == + _RunningMode.LIVE_STREAM), options.running_mode, + packets_callback if options.result_callback else None) + + def recognize( + self, + image: image_module.Image, + roi: Optional[_NormalizedRect] = None + ) -> GestureRecognitionResult: + """Performs hand gesture recognition on the given image. Only use this + method when the GestureRecognizer is created with the image running mode. + + The image can be of any size with format RGB or RGBA. + TODO: Describes how the input image will be preprocessed after the yuv + support is implemented. + + Args: + image: MediaPipe Image. + roi: The region of interest. + + Returns: + The hand gesture recognition results. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If gesture recognition failed to run. 
+ """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + norm_rect.to_pb2())}) + gestures_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_GESTURE_STREAM_NAME]) + handedness_proto_list = packet_getter.get_proto_list( + output_packets[_HANDEDNESS_STREAM_NAME]) + hand_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_LANDMARKS_STREAM_NAME]) + hand_world_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + + return GestureRecognitionResult( + [ + classification_module.ClassificationList.create_from_pb2(gestures) + for gestures in gestures_proto_list + ], [ + classification_module.ClassificationList.create_from_pb2(handedness) + for handedness in handedness_proto_list + ], [ + landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks) + for hand_landmarks in hand_landmarks_proto_list + ], [ + landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks) + for hand_world_landmarks in hand_world_landmarks_proto_list + ] + ) + + def recognize_for_video( + self, image: image_module.Image, + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None + ) -> GestureRecognitionResult: + """Performs gesture recognition on the provided video frame. Only use this + method when the GestureRecognizer is created with the video running mode. + + Only use this method when the GestureRecognizer is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + roi: The region of interest. + + Returns: + The hand gesture recognition results. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If gesture recognition failed to run. 
+ """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) + gestures_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_GESTURE_STREAM_NAME]) + handedness_proto_list = packet_getter.get_proto_list( + output_packets[_HANDEDNESS_STREAM_NAME]) + hand_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_LANDMARKS_STREAM_NAME]) + hand_world_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + + return GestureRecognitionResult( + [ + classification_module.ClassificationList.create_from_pb2(gestures) + for gestures in gestures_proto_list + ], [ + classification_module.ClassificationList.create_from_pb2(handedness) + for handedness in handedness_proto_list + ], [ + landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks) + for hand_landmarks in hand_landmarks_proto_list + ], [ + landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks) + for hand_world_landmarks in hand_world_landmarks_proto_list + ] + ) + + def recognize_async( + self, + image: image_module.Image, + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None + ) -> None: + """Sends live image data to perform gesture recognition, and the results + will be available via the "result_callback" provided in the + GestureRecognizerOptions. Only use this method when the GestureRecognizer + is created with the live stream running mode. + + Only use this method when the GestureRecognizer is created with the live + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input image is accepted. The results will be available via the + `result_callback` provided in the `GestureRecognizerOptions`. The + `recognize_async` method is designed to process live stream data such as + camera input. To lower the overall latency, gesture recognizer may drop the + input images if needed. In other words, it's not guaranteed to have output + per input image. + + The `result_callback` provides: + - The hand gesture recognition results. + - The input image that the image classifier runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + roi: The region of interest. + + Raises: + ValueError: If the current input timestamp is smaller than what the + gesture recognizer has already processed. 
+ """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) From 18eb089d39356ade117fbc629c5e19bef35f2d22 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 25 Oct 2022 07:38:04 -0700 Subject: [PATCH 02/34] Added a simple test to verify gesture recognition results --- .../tasks/python/components/containers/BUILD | 6 +- .../components/containers/classification.py | 19 ++- .../python/components/containers/gesture.py | 138 ------------------ .../containers/landmark_detection_result.py | 82 +++++++++++ mediapipe/tasks/python/test/vision/BUILD | 6 +- .../test/vision/gesture_recognizer_test.py | 85 ++++++++++- .../tasks/python/vision/gesture_recognizer.py | 15 +- mediapipe/tasks/testdata/vision/BUILD | 1 + 8 files changed, 184 insertions(+), 168 deletions(-) delete mode 100644 mediapipe/tasks/python/components/containers/gesture.py create mode 100644 mediapipe/tasks/python/components/containers/landmark_detection_result.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 325dff5fc..8aaa64cc9 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -55,11 +55,13 @@ py_library( ) py_library( - name = "gesture", - srcs = ["gesture.py"], + name = "landmark_detection_result", + srcs = ["landmark_detection_result.py"], deps = [ + ":rect", ":classification", ":landmark", + "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/classification.py b/mediapipe/tasks/python/components/containers/classification.py index 157c34528..465e2dd28 100644 --- a/mediapipe/tasks/python/components/containers/classification.py +++ b/mediapipe/tasks/python/components/containers/classification.py @@ -14,14 +14,13 @@ """Classification data class.""" import dataclasses -from typing import Any, List +from typing import Any, List, Optional from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassificationProto = classification_pb2.Classification _ClassificationListProto = classification_pb2.ClassificationList -_ClassificationListCollectionProto = classification_pb2.ClassificationListCollection @dataclasses.dataclass @@ -35,10 +34,10 @@ class Classification: display_name: Optional human-readable string for display purposes. 
""" - index: int - score: float - label_name: str - display_name: str + index: Optional[int] = None + score: Optional[float] = None + label: Optional[str] = None + display_name: Optional[str] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ClassificationProto: @@ -46,7 +45,7 @@ class Classification: return _ClassificationProto( index=self.index, score=self.score, - label_name=self.label_name, + label=self.label, display_name=self.display_name) @classmethod @@ -56,7 +55,7 @@ class Classification: return Classification( index=pb2_obj.index, score=pb2_obj.score, - label_name=pb2_obj.label_name, + label=pb2_obj.label, display_name=pb2_obj.display_name) def __eq__(self, other: Any) -> bool: @@ -86,8 +85,8 @@ class ClassificationList: """ classifications: List[Classification] - tensor_index: int - tensor_name: str + tensor_index: Optional[int] = None + tensor_name: Optional[str] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ClassificationListProto: diff --git a/mediapipe/tasks/python/components/containers/gesture.py b/mediapipe/tasks/python/components/containers/gesture.py deleted file mode 100644 index f314d18bd..000000000 --- a/mediapipe/tasks/python/components/containers/gesture.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Gesture data class.""" - -import dataclasses -from typing import Any, List - -from mediapipe.tasks.python.components.containers import classification -from mediapipe.tasks.python.components.containers import landmark -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - - -@dataclasses.dataclass -class GestureRecognitionResult: - """ The gesture recognition result from GestureRecognizer, where each vector - element represents a single hand detected in the image. - - Attributes: - gestures: Recognized hand gestures with sorted order such that the - winning label is the first item in the list. - handedness: Classification of handedness. - hand_landmarks: Detected hand landmarks in normalized image coordinates. - hand_world_landmarks: Detected hand landmarks in world coordinates. 
- """ - - gestures: List[classification.ClassificationList] - handedness: List[classification.ClassificationList] - hand_landmarks: List[landmark.NormalizedLandmarkList] - hand_world_landmarks: List[landmark.LandmarkList] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _DetectionProto: - """Generates a Detection protobuf object.""" - labels = [] - label_ids = [] - scores = [] - display_names = [] - for category in self.categories: - scores.append(category.score) - if category.index: - label_ids.append(category.index) - if category.category_name: - labels.append(category.category_name) - if category.display_name: - display_names.append(category.display_name) - return _DetectionProto( - label=labels, - label_id=label_ids, - score=scores, - display_name=display_names, - location_data=_LocationDataProto( - format=_LocationDataProto.Format.BOUNDING_BOX, - bounding_box=self.bounding_box.to_pb2())) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection': - """Creates a `Detection` object from the given protobuf object.""" - categories = [] - for idx, score in enumerate(pb2_obj.score): - categories.append( - category_module.Category( - score=score, - index=pb2_obj.label_id[idx] - if idx < len(pb2_obj.label_id) else None, - category_name=pb2_obj.label[idx] - if idx < len(pb2_obj.label) else None, - display_name=pb2_obj.display_name[idx] - if idx < len(pb2_obj.display_name) else None)) - - return Detection( - bounding_box=bounding_box_module.BoundingBox.create_from_pb2( - pb2_obj.location_data.bounding_box), - categories=categories) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, Detection): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class DetectionResult: - """Represents the list of detected objects. - - Attributes: - detections: A list of `Detection` objects. - """ - - detections: List[Detection] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _DetectionListProto: - """Generates a DetectionList protobuf object.""" - return _DetectionListProto( - detection=[detection.to_pb2() for detection in self.detections]) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult': - """Creates a `DetectionResult` object from the given protobuf object.""" - return DetectionResult(detections=[ - Detection.create_from_pb2(detection) for detection in pb2_obj.detection - ]) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, DetectionResult): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py new file mode 100644 index 000000000..c3d93d414 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -0,0 +1,82 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Landmark Detection Result data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 +from mediapipe.tasks.python.components.containers import rect as rect_module +from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult +_NormalizedRect = rect_module.NormalizedRect +_ClassificationList = classification_module.ClassificationList +_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList +_LandmarkList = landmark_module.LandmarkList + + +@dataclasses.dataclass +class LandmarksDetectionResult: + """Represents the landmarks detection result. + + Attributes: + landmarks : A `NormalizedLandmarkList` object. + classifications : A `ClassificationList` object. + world_landmarks : A `LandmarkList` object. + rect : A `NormalizedRect` object. + """ + + landmarks: Optional[_NormalizedLandmarkList] + classifications: Optional[_ClassificationList] + world_landmarks: Optional[_LandmarkList] + rect: _NormalizedRect + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _LandmarksDetectionResultProto: + """Generates a LandmarksDetectionResult protobuf object.""" + return _LandmarksDetectionResultProto( + landmarks=self.landmarks.to_pb2(), + classifications=self.classifications.to_pb2(), + world_landmarks=self.world_landmarks.to_pb2(), + rect=self.rect.to_pb2()) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _LandmarksDetectionResultProto + ) -> 'LandmarksDetectionResult': + """Creates a `LandmarksDetectionResult` object from the given protobuf + object.""" + return LandmarksDetectionResult( + landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks), + classifications=_ClassificationList.create_from_pb2( + pb2_obj.classifications), + world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks), + rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. 
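
A small usage sketch of this new container: parse a LandmarksDetectionResult text proto, as the test added below does, and convert it into the dataclass. The file name refers to the test asset used later and should be treated as a placeholder path.

from google.protobuf import text_format
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import landmark_detection_result

proto = landmarks_detection_result_pb2.LandmarksDetectionResult()
with open('thumb_up_landmarks.pbtxt') as f:  # placeholder path to the .pbtxt asset
  text_format.Parse(f.read(), proto)
result = landmark_detection_result.LandmarksDetectionResult.create_from_pb2(proto)
print(len(result.landmarks.landmarks))  # 21 hand landmarks for a full hand
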
+ """ + if not isinstance(other, LandmarksDetectionResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 0dd83edcf..0d8b99984 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -43,15 +43,19 @@ py_test( data = [ "//mediapipe/tasks/testdata/vision:test_images", "//mediapipe/tasks/testdata/vision:test_models", + "//mediapipe/tasks/testdata/vision:test_protos", ], deps = [ "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:classification", "//mediapipe/tasks/python/components/containers:landmark", - "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/components/containers:landmark_detection_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:gesture_recognizer", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@com_google_protobuf//:protobuf_python" ], ) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 288cfd1f5..7d731d805 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -15,23 +15,31 @@ import enum +from google.protobuf import text_format from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.components.containers import classification as classification_module from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import gesture_recognizer from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult _BaseOptions = base_options_module.BaseOptions _NormalizedRect = rect_module.NormalizedRect +_Classification = classification_module.Classification _ClassificationList = classification_module.ClassificationList +_Landmark = landmark_module.Landmark _LandmarkList = landmark_module.LandmarkList +_NormalizedLandmark = landmark_module.NormalizedLandmark _NormalizedLandmarkList = landmark_module.NormalizedLandmarkList +_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult _Image = image_module.Image _GestureRecognizer = gesture_recognizer.GestureRecognizer _GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions @@ -39,8 +47,35 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task' -_IMAGE_FILE = 'right_hands.jpg' -_EXPECTED_DETECTION_RESULT = 
_GestureRecognitionResult([], [], [], []) +_THUMB_UP_IMAGE = 'thumb_up.jpg' +_THUMB_UP_LANDMARKS = "thumb_up_landmarks.pbtxt" +_THUMB_UP_LABEL = "Thumb_Up" +_THUMB_UP_INDEX = 5 +_LANDMARKS_ERROR_TOLERANCE = 0.03 + + +def _get_expected_gesture_recognition_result( + file_path: str, gesture_label: str, gesture_index: int +) -> _GestureRecognitionResult: + landmarks_detection_result_file_path = test_utils.get_test_data_path( + file_path) + with open(landmarks_detection_result_file_path, "rb") as f: + landmarks_detection_result_proto = _LandmarksDetectionResultProto() + # # Use this if a .pb file is available. + # landmarks_detection_result_proto.ParseFromString(f.read()) + text_format.Parse(f.read(), landmarks_detection_result_proto) + landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( + landmarks_detection_result_proto) + gesture = _ClassificationList( + classifications=[ + _Classification(label=gesture_label, index=gesture_index, + display_name='') + ], tensor_index=0, tensor_name='') + return _GestureRecognitionResult( + gestures=[gesture], + handedness=[landmarks_detection_result.classifications], + hand_landmarks=[landmarks_detection_result.landmarks], + hand_world_landmarks=[landmarks_detection_result.world_landmarks]) class ModelFileType(enum.Enum): @@ -53,14 +88,45 @@ class GestureRecognizerTest(parameterized.TestCase): def setUp(self): super().setUp() self.test_image = _Image.create_from_file( - test_utils.get_test_data_path(_IMAGE_FILE)) + test_utils.get_test_data_path(_THUMB_UP_IMAGE)) self.gesture_recognizer_model_path = test_utils.get_test_data_path( _GESTURE_RECOGNIZER_MODEL_FILE) + def _assert_actual_result_approximately_matches_expected_result( + self, + actual_result: _GestureRecognitionResult, + expected_result: _GestureRecognitionResult + ): + # Expects to have the same number of hands detected. + self.assertLen(actual_result.hand_landmarks, + len(expected_result.hand_landmarks)) + self.assertLen(actual_result.hand_world_landmarks, + len(expected_result.hand_world_landmarks)) + self.assertLen(actual_result.handedness, len(expected_result.handedness)) + self.assertLen(actual_result.gestures, len(expected_result.gestures)) + # Actual landmarks match expected landmarks. + self.assertEqual(actual_result.hand_landmarks, + expected_result.hand_landmarks) + # Actual handedness matches expected handedness. + actual_top_handedness = actual_result.handedness[0].classifications[0] + expected_top_handedness = expected_result.handedness[0].classifications[0] + self.assertEqual(actual_top_handedness.index, expected_top_handedness.index) + self.assertEqual(actual_top_handedness.label, expected_top_handedness.label) + # Actual gesture with top score matches expected gesture. 
+ actual_top_gesture = actual_result.gestures[0].classifications[0] + expected_top_gesture = expected_result.gestures[0].classifications[0] + self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) + self.assertEqual(actual_top_gesture.label, expected_top_gesture.label) + @parameterized.parameters( - (ModelFileType.FILE_NAME, _EXPECTED_DETECTION_RESULT), - (ModelFileType.FILE_CONTENT, _EXPECTED_DETECTION_RESULT)) - def test_recognize(self, model_file_type, expected_recognition_result): + (ModelFileType.FILE_NAME, 0.3, _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + )), + (ModelFileType.FILE_CONTENT, 0.3, _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + ))) + def test_recognize(self, model_file_type, min_gesture_confidence, + expected_recognition_result): # Creates gesture recognizer. if model_file_type is ModelFileType.FILE_NAME: gesture_recognizer_base_options = _BaseOptions( @@ -75,13 +141,16 @@ class GestureRecognizerTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _GestureRecognizerOptions( - base_options=gesture_recognizer_base_options) + base_options=gesture_recognizer_base_options, + min_gesture_confidence=min_gesture_confidence + ) recognizer = _GestureRecognizer.create_from_options(options) # Performs hand gesture recognition on the input. recognition_result = recognizer.recognize(self.test_image) # Comparing results. - self.assertEqual(recognition_result, expected_recognition_result) + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) # Closes the gesture recognizer explicitly when the detector is not used in # a context. recognizer.close() diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index aca7a5277..c00508b36 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -136,8 +136,6 @@ class GestureRecognizerOptions: """Generates an GestureRecognizerOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - # hand_landmark_detector_base_options_proto = self.hand_landmark_detector_base_options.to_pb2() - # hand_landmark_detector_base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True # Configure hand detector options. hand_detector_options_proto = _HandDetectorGraphOptionsProto( @@ -153,13 +151,12 @@ class GestureRecognizerOptions: min_tracking_confidence=self.min_tracking_confidence) # Configure hand gesture recognizer options. 
- hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() - if self.min_gesture_confidence >= 0: - classifier_options = _ClassifierOptions( - score_threshold=self.min_gesture_confidence) - hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options = \ - _GestureClassifierGraphOptionsProto( - classifier_options=classifier_options.to_pb2()) + classifier_options = _ClassifierOptions( + score_threshold=self.min_gesture_confidence) + gesture_classifier_options = _GestureClassifierGraphOptionsProto( + classifier_options=classifier_options.to_pb2()) + hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto( + canned_gesture_classifier_graph_options=gesture_classifier_options) return _GestureRecognizerGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ebb8f05a6..365921bc1 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -121,6 +121,7 @@ filegroup( "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "hand_landmarker.task", + "gesture_recognizer.task", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", From 8762d15c81ee201188cfc482f0bbcc8b18ec0530 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 25 Oct 2022 11:11:15 -0700 Subject: [PATCH 03/34] Added remaining tests for the GestureRecognizer Python MediaPipe Tasks API --- .../containers/landmark_detection_result.py | 2 +- mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/gesture_recognizer_test.py | 311 ++++++++++++++++-- mediapipe/tasks/python/vision/BUILD | 1 - mediapipe/tasks/python/vision/core/BUILD | 9 + .../vision/core/base_vision_task_api.py | 49 +++ .../vision/core/image_processing_options.py | 39 +++ .../tasks/python/vision/gesture_recognizer.py | 44 +-- 8 files changed, 414 insertions(+), 42 deletions(-) create mode 100644 mediapipe/tasks/python/vision/core/image_processing_options.py diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py index c3d93d414..02ca5a918 100644 --- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Landmark Detection Result data class.""" +"""Landmarks Detection Result data class.""" import dataclasses from typing import Any, Optional diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 0d8b99984..6455d7fcf 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -56,6 +56,7 @@ py_test( "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:gesture_recognizer", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "//mediapipe/tasks/python/vision/core:image_processing_options", "@com_google_protobuf//:protobuf_python" ], ) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 7d731d805..a8316c528 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -14,7 +14,9 @@ """Tests for gesture recognizer.""" import enum +from unittest import mock +import numpy as np from google.protobuf import text_format from absl.testing import absltest from absl.testing import parameterized @@ -29,10 +31,11 @@ from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import gesture_recognizer from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module _LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult _BaseOptions = base_options_module.BaseOptions -_NormalizedRect = rect_module.NormalizedRect +_Rect = rect_module.Rect _Classification = classification_module.Classification _ClassificationList = classification_module.ClassificationList _Landmark = landmark_module.Landmark @@ -45,12 +48,19 @@ _GestureRecognizer = gesture_recognizer.GestureRecognizer _GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult _RUNNING_MODE = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task' +_NO_HANDS_IMAGE = 'cats_and_dogs.jpg' +_TWO_HANDS_IMAGE = 'right_hands.jpg' _THUMB_UP_IMAGE = 'thumb_up.jpg' -_THUMB_UP_LANDMARKS = "thumb_up_landmarks.pbtxt" -_THUMB_UP_LABEL = "Thumb_Up" +_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' +_THUMB_UP_LABEL = 'Thumb_Up' _THUMB_UP_INDEX = 5 +_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg' +_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' +_POINTING_UP_LABEL = 'Pointing_Up' +_POINTING_UP_INDEX = 3 _LANDMARKS_ERROR_TOLERANCE = 0.03 @@ -89,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase): super().setUp() self.test_image = _Image.create_from_file( test_utils.get_test_data_path(_THUMB_UP_IMAGE)) - self.gesture_recognizer_model_path = test_utils.get_test_data_path( + self.model_path = test_utils.get_test_data_path( _GESTURE_RECOGNIZER_MODEL_FILE) def _assert_actual_result_approximately_matches_expected_result( @@ -105,8 +115,15 @@ class GestureRecognizerTest(parameterized.TestCase): self.assertLen(actual_result.handedness, len(expected_result.handedness)) self.assertLen(actual_result.gestures, len(expected_result.gestures)) # Actual landmarks match expected 
landmarks. - self.assertEqual(actual_result.hand_landmarks, - expected_result.hand_landmarks) + self.assertLen(actual_result.hand_landmarks[0].landmarks, + len(expected_result.hand_landmarks[0].landmarks)) + actual_landmarks = actual_result.hand_landmarks[0].landmarks + expected_landmarks = expected_result.hand_landmarks[0].landmarks + for i in range(len(actual_landmarks)): + self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x, + delta=_LANDMARKS_ERROR_TOLERANCE) + self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y, + delta=_LANDMARKS_ERROR_TOLERANCE) # Actual handedness matches expected handedness. actual_top_handedness = actual_result.handedness[0].classifications[0] expected_top_handedness = expected_result.handedness[0].classifications[0] @@ -118,32 +135,56 @@ class GestureRecognizerTest(parameterized.TestCase): self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) self.assertEqual(actual_top_gesture.label, expected_top_gesture.label) + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _GestureRecognizer.create_from_model_path(self.model_path) as recognizer: + self.assertIsInstance(recognizer, _GestureRecognizer) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _GestureRecognizerOptions(base_options=base_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + self.assertIsInstance(recognizer, _GestureRecognizer) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _GestureRecognizerOptions(base_options=base_options) + _GestureRecognizer.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _GestureRecognizerOptions(base_options=base_options) + recognizer = _GestureRecognizer.create_from_options(options) + self.assertIsInstance(recognizer, _GestureRecognizer) + @parameterized.parameters( - (ModelFileType.FILE_NAME, 0.3, _get_expected_gesture_recognition_result( + (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result( _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX )), - (ModelFileType.FILE_CONTENT, 0.3, _get_expected_gesture_recognition_result( + (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result( _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX ))) - def test_recognize(self, model_file_type, min_gesture_confidence, - expected_recognition_result): + def test_recognize(self, model_file_type, expected_recognition_result): # Creates gesture recognizer. 
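
For orientation, a minimal end-to-end sketch of the image-mode API these tests exercise; the model and image file names are placeholders, and the snippet only illustrates the intended call pattern.

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import gesture_recognizer

options = gesture_recognizer.GestureRecognizerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='gesture_recognizer.task'),  # placeholder model path
    num_hands=1,
    min_gesture_confidence=0.5)
with gesture_recognizer.GestureRecognizer.create_from_options(options) as recognizer:
  image = image_module.Image.create_from_file('thumb_up.jpg')  # placeholder image
  result = recognizer.recognize(image)
  if result.gestures:
    top_gesture = result.gestures[0].classifications[0]
    print(top_gesture.label, top_gesture.score)
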
if model_file_type is ModelFileType.FILE_NAME: - gesture_recognizer_base_options = _BaseOptions( - model_asset_path=self.gesture_recognizer_model_path) + base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: - with open(self.gesture_recognizer_model_path, 'rb') as f: + with open(self.model_path, 'rb') as f: model_content = f.read() - gesture_recognizer_base_options = _BaseOptions( - model_asset_buffer=model_content) + base_options = _BaseOptions(model_asset_buffer=model_content) else: # Should never happen raise ValueError('model_file_type is invalid.') - options = _GestureRecognizerOptions( - base_options=gesture_recognizer_base_options, - min_gesture_confidence=min_gesture_confidence - ) + options = _GestureRecognizerOptions(base_options=base_options) recognizer = _GestureRecognizer.create_from_options(options) # Performs hand gesture recognition on the input. @@ -151,10 +192,238 @@ class GestureRecognizerTest(parameterized.TestCase): # Comparing results. self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) - # Closes the gesture recognizer explicitly when the detector is not used in - # a context. + # Closes the gesture recognizer explicitly when the gesture recognizer is + # not used in a context. recognizer.close() + @parameterized.parameters( + (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + )), + (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + ))) + def test_recognize_in_context(self, model_file_type, + expected_recognition_result): + # Creates gesture recognizer. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _GestureRecognizerOptions(base_options=base_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(self.test_image) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + + def test_recognize_succeeds_with_num_hands(self): + # Creates gesture recognizer. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _GestureRecognizerOptions(base_options=base_options, num_hands=2) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up rotated image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_TWO_HANDS_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + # Comparing results. + self.assertLen(recognition_result.handedness, 2) + + def test_recognize_succeeds_with_rotation(self): + # Creates gesture recognizer. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up rotated image. 
+ test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_ROTATED_IMAGE)) + # Set rotation parameters using ImageProcessingOptions. + image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image, + image_processing_options) + expected_recognition_result = _get_expected_gesture_recognition_result( + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + + def test_recognize_fails_with_region_of_interest(self): + # Creates gesture recognizer. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) + with self.assertRaisesRegex( + ValueError, "This task doesn't support region-of-interest."): + with _GestureRecognizer.create_from_options(options) as recognizer: + # Set the `region_of_interest` parameter using `ImageProcessingOptions`. + image_processing_options = _ImageProcessingOptions( + region_of_interest=_Rect(0, 0, 1, 1)) + # Attempt to perform hand gesture recognition on the cropped input. + recognizer.recognize(self.test_image, image_processing_options) + + def test_empty_recognition_outputs(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path)) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the image with no hands. + no_hands_test_image = _Image.create_from_file( + test_utils.get_test_data_path(_NO_HANDS_IMAGE)) + # Performs gesture recognition on the input. + recognition_result = recognizer.recognize(no_hands_test_image) + self.assertEmpty(recognition_result.hand_landmarks) + self.assertEmpty(recognition_result.hand_world_landmarks) + self.assertEmpty(recognition_result.handedness) + self.assertEmpty(recognition_result.gestures) + + def test_missing_result_callback(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _GestureRecognizer.create_from_options(options) as unused_recognizer: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _GestureRecognizer.create_from_options(options) as unused_recognizer: + pass + + def test_calling_recognize_for_video_in_image_mode(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _GestureRecognizer.create_from_options(options) as recognizer: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + recognizer.recognize_for_video(self.test_image, 0) + + def test_calling_recognize_async_in_image_mode(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _GestureRecognizer.create_from_options(options) as recognizer: + with 
self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + recognizer.recognize_async(self.test_image, 0) + + def test_calling_recognize_in_video_mode(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _GestureRecognizer.create_from_options(options) as recognizer: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + recognizer.recognize(self.test_image) + + def test_calling_recognize_async_in_video_mode(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _GestureRecognizer.create_from_options(options) as recognizer: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + recognizer.recognize_async(self.test_image, 0) + + def test_recognize_for_video_with_out_of_order_timestamp(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _GestureRecognizer.create_from_options(options) as recognizer: + unused_result = recognizer.recognize_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + recognizer.recognize_for_video(self.test_image, 0) + + def test_recognize_for_video(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _GestureRecognizer.create_from_options(options) as recognizer: + for timestamp in range(0, 300, 30): + recognition_result = recognizer.recognize_for_video(self.test_image, + timestamp) + expected_recognition_result = _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + + def test_calling_recognize_in_live_stream_mode(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _GestureRecognizer.create_from_options(options) as recognizer: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + recognizer.recognize(self.test_image) + + def test_calling_recognize_for_video_in_live_stream_mode(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _GestureRecognizer.create_from_options(options) as recognizer: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + recognizer.recognize_for_video(self.test_image, 0) + + def test_recognize_async_calls_with_illegal_timestamp(self): + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _GestureRecognizer.create_from_options(options) as recognizer: + recognizer.recognize_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + recognizer.recognize_async(self.test_image, 0) + + @parameterized.parameters( + (_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result( + 
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)), + (_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], []))) + def test_recognize_async_calls(self, image_path, expected_result): + test_image = _Image.create_from_file( + test_utils.get_test_data_path(image_path)) + observed_timestamp_ms = -1 + + def check_result(result: _GestureRecognitionResult, output_image: _Image, + timestamp_ms: int): + if result.hand_landmarks and result.hand_world_landmarks and \ + result.handedness and result.gestures: + self._assert_actual_result_approximately_matches_expected_result( + result, expected_result) + else: + self.assertEqual(result, expected_result) + self.assertTrue( + np.array_equal(output_image.numpy_view(), + test_image.numpy_view())) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + options = _GestureRecognizerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=check_result) + with _GestureRecognizer.create_from_options(options) as recognizer: + for timestamp in range(0, 300, 30): + recognizer.recognize_async(test_image, timestamp) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 9a9ca3429..3b2a7e50d 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -52,7 +52,6 @@ py_library( "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2", - "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:classification", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/processors:classifier_options", diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index df1b06f4c..ddb7c024e 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -23,6 +23,14 @@ py_library( srcs = ["vision_task_running_mode.py"], ) +py_library( + name = "image_processing_options", + srcs = ["image_processing_options.py"], + deps = [ + "//mediapipe/tasks/python/components/containers:rect", + ], +) + py_library( name = "base_vision_task_api", srcs = [ @@ -30,6 +38,7 @@ py_library( ], deps = [ ":vision_task_running_mode", + ":image_processing_options", "//mediapipe/framework:calculator_py_pb2", "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index b2f8a366a..be290c83c 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -13,17 +13,22 @@ # limitations under the License. 
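
To complement the live-stream tests above, a rough sketch of the asynchronous API; file names are placeholders and the frames would normally come from a camera feed rather than a single image.

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import gesture_recognizer
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module


def print_result(result, output_image, timestamp_ms):
  # Invoked asynchronously; gestures can be empty when no hand is detected.
  if result.gestures:
    print(timestamp_ms, result.gestures[0].classifications[0].label)


options = gesture_recognizer.GestureRecognizerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='gesture_recognizer.task'),  # placeholder model path
    running_mode=running_mode_module.VisionTaskRunningMode.LIVE_STREAM,
    result_callback=print_result)
with gesture_recognizer.GestureRecognizer.create_from_options(options) as recognizer:
  frame = image_module.Image.create_from_file('thumb_up.jpg')  # placeholder frame
  # Timestamps are milliseconds and must be monotonically increasing; they are
  # converted to microseconds (x1000) before being sent into the graph.
  for timestamp_ms in range(0, 300, 30):
    recognizer.recognize_async(frame, timestamp_ms)
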
"""MediaPipe vision task base api.""" +import math from typing import Callable, Mapping, Optional from mediapipe.framework import calculator_pb2 from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module _TaskRunner = task_runner_module.TaskRunner _Packet = packet_module.Packet +_NormalizedRect = rect_module.NormalizedRect _RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions class BaseVisionTaskApi(object): @@ -122,6 +127,50 @@ class BaseVisionTaskApi(object): + self._running_mode.name) self._runner.send(inputs) + @staticmethod + def convert_to_normalized_rect( + options: _ImageProcessingOptions, + roi_allowed: bool = True + ) -> _NormalizedRect: + """ + Convert from ImageProcessingOptions to NormalizedRect, performing sanity + checks on-the-fly. If the input ImageProcessingOptions is not present, + returns a default NormalizedRect covering the whole image with rotation set + to 0. If 'roi_allowed' is false, an error will be returned if the input + ImageProcessingOptions has its 'region_of_interest' field set. + + Args: + options: Options for image processing. + roi_allowed: Indicates if the `region_of_interest` field is allowed to be + set. By default, it's set to True. + + """ + normalized_rect = _NormalizedRect(rotation=0, x_center=0.5, y_center=0.5, + width=1, height=1) + if options is None: + return normalized_rect + + if options.rotation_degrees % 90 != 0: + raise ValueError("Expected rotation to be a multiple of 90°.") + + # Convert to radians counter-clockwise. + normalized_rect.rotation = -options.rotation_degrees * math.pi / 180.0 + + if options.region_of_interest: + if not roi_allowed: + raise ValueError("This task doesn't support region-of-interest.") + roi = options.region_of_interest + if roi.x_center >= roi.width or roi.y_center >= roi.height: + raise ValueError( + "Expected Rect with x_center < width and y_center < height.") + if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1: + raise ValueError("Expected Rect values to be in [0,1].") + normalized_rect.x_center = roi.x_center + roi.width / 2.0 + normalized_rect.y_center = roi.y_center + roi.height / 2.0 + normalized_rect.width = roi.width - roi.x_center + normalized_rect.height = roi.height - roi.y_center + return normalized_rect + def close(self) -> None: """Shuts down the mediapipe vision task instance. diff --git a/mediapipe/tasks/python/vision/core/image_processing_options.py b/mediapipe/tasks/python/vision/core/image_processing_options.py new file mode 100644 index 000000000..2a3a13088 --- /dev/null +++ b/mediapipe/tasks/python/vision/core/image_processing_options.py @@ -0,0 +1,39 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe vision options for image processing.""" + +import dataclasses +from typing import Optional + +from mediapipe.tasks.python.components.containers import rect as rect_module + + +@dataclasses.dataclass +class ImageProcessingOptions: + """Options for image processing. + + If both region-of-interest and rotation are specified, the crop around the + region-of-interest is extracted first, then the specified rotation is applied + to the crop. + + Attributes: + region_of_interest: The optional region-of-interest to crop from the image. + If not specified, the full image is used. Coordinates must be in [0,1] + with 'left' < 'right' and 'top' < bottom. + rotation_degress: The rotation to apply to the image (or cropped + region-of-interest), in degrees clockwise. The rotation must be a + multiple (positive or negative) of 90°. + """ + region_of_interest: Optional[rect_module.Rect] = None + rotation_degrees: int = 0 diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index c00508b36..0036aa877 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -27,7 +27,6 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2 -from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.components.containers import classification as classification_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.processors import classifier_options @@ -36,8 +35,8 @@ from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module -_NormalizedRect = rect_module.NormalizedRect _BaseOptions = base_options_module.BaseOptions _GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions @@ -47,6 +46,7 @@ _HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmar _HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo _TaskRunner = 
task_runner_module.TaskRunner @@ -67,11 +67,6 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerG _MICRO_SECONDS_PER_MILLISECOND = 1000 -def _build_full_image_norm_rect() -> _NormalizedRect: - # Builds a NormalizedRect covering the entire image. - return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) - - @dataclasses.dataclass class GestureRecognitionResult: """The gesture recognition result from GestureRecognizer, where each vector @@ -278,7 +273,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): def recognize( self, image: image_module.Image, - roi: Optional[_NormalizedRect] = None + image_processing_options: Optional[_ImageProcessingOptions] = None ) -> GestureRecognitionResult: """Performs hand gesture recognition on the given image. Only use this method when the GestureRecognizer is created with the image running mode. @@ -289,7 +284,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. - roi: The region of interest. + image_processing_options: Options for image processing. Returns: The hand gesture recognition results. @@ -298,11 +293,16 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If gesture recognition failed to run. """ - norm_rect = roi if roi is not None else _build_full_image_norm_rect() + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) output_packets = self._process_image_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - norm_rect.to_pb2())}) + normalized_rect.to_pb2())}) + + if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): + return GestureRecognitionResult([], [], [], []) + gestures_proto_list = packet_getter.get_proto_list( output_packets[_HAND_GESTURE_STREAM_NAME]) handedness_proto_list = packet_getter.get_proto_list( @@ -331,7 +331,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): def recognize_for_video( self, image: image_module.Image, timestamp_ms: int, - roi: Optional[_NormalizedRect] = None + image_processing_options: Optional[_ImageProcessingOptions] = None ) -> GestureRecognitionResult: """Performs gesture recognition on the provided video frame. Only use this method when the GestureRecognizer is created with the video running mode. @@ -344,7 +344,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. - roi: The region of interest. + image_processing_options: Options for image processing. Returns: The hand gesture recognition results. @@ -353,14 +353,19 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If gesture recognition failed to run. 
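
Similarly, a short sketch of the video running mode served by recognize_for_video; the same placeholder file names apply, and a real caller would feed decoded video frames with their frame timestamps.

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import gesture_recognizer
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

options = gesture_recognizer.GestureRecognizerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='gesture_recognizer.task'),  # placeholder model path
    running_mode=running_mode_module.VisionTaskRunningMode.VIDEO)
with gesture_recognizer.GestureRecognizer.create_from_options(options) as recognizer:
  frame = image_module.Image.create_from_file('thumb_up.jpg')  # placeholder frame
  for timestamp_ms in range(0, 300, 30):  # milliseconds, strictly increasing
    result = recognizer.recognize_for_video(frame, timestamp_ms)
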
""" - norm_rect = roi if roi is not None else _build_full_image_norm_rect() + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) output_packets = self._process_video_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - norm_rect.to_pb2()).at( + normalized_rect.to_pb2()).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) + + if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): + return GestureRecognitionResult([], [], [], []) + gestures_proto_list = packet_getter.get_proto_list( output_packets[_HAND_GESTURE_STREAM_NAME]) handedness_proto_list = packet_getter.get_proto_list( @@ -390,7 +395,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - roi: Optional[_NormalizedRect] = None + image_processing_options: Optional[_ImageProcessingOptions] = None ) -> None: """Sends live image data to perform gesture recognition, and the results will be available via the "result_callback" provided in the @@ -415,17 +420,18 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. - roi: The region of interest. + image_processing_options: Options for image processing. Raises: ValueError: If the current input timestamp is smaller than what the gesture recognizer has already processed. """ - norm_rect = roi if roi is not None else _build_full_image_norm_rect() + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) self._send_live_stream_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - norm_rect.to_pb2()).at( + normalized_rect.to_pb2()).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) From 75af46d2739245bb46de9fd30e541e2b6de3b077 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 25 Oct 2022 23:13:12 -0700 Subject: [PATCH 04/34] Revised API to align with recent changes --- .../python/components/containers/classification.py | 10 ++-------- .../python/test/vision/gesture_recognizer_test.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/python/components/containers/classification.py b/mediapipe/tasks/python/components/containers/classification.py index 465e2dd28..a9225e804 100644 --- a/mediapipe/tasks/python/components/containers/classification.py +++ b/mediapipe/tasks/python/components/containers/classification.py @@ -85,8 +85,6 @@ class ClassificationList: """ classifications: List[Classification] - tensor_index: Optional[int] = None - tensor_name: Optional[str] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ClassificationListProto: @@ -95,9 +93,7 @@ class ClassificationList: classification=[ classification.to_pb2() for classification in self.classifications - ], - tensor_index=self.tensor_index, - tensor_name=self.tensor_name) + ]) @classmethod @doc_controls.do_not_generate_docs @@ -110,9 +106,7 @@ class ClassificationList: classifications=[ Classification.create_from_pb2(classification) for classification in pb2_obj.classification - ], - tensor_index=pb2_obj.tensor_index, - tensor_name=pb2_obj.tensor_name) + ]) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. 
diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index a8316c528..3bf994a1d 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -80,7 +80,7 @@ def _get_expected_gesture_recognition_result( classifications=[ _Classification(label=gesture_label, index=gesture_index, display_name='') - ], tensor_index=0, tensor_name='') + ]) return _GestureRecognitionResult( gestures=[gesture], handedness=[landmarks_detection_result.classifications], From fbf7ba6f1a5259c93ff3b32a968643e2f7b4b454 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 25 Oct 2022 23:15:16 -0700 Subject: [PATCH 05/34] Reverted some changes to rect --- mediapipe/tasks/python/components/containers/rect.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py index 4b943a551..510561592 100644 --- a/mediapipe/tasks/python/components/containers/rect.py +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -26,6 +26,7 @@ _NormalizedRectProto = rect_pb2.NormalizedRect @dataclasses.dataclass class Rect: """A rectangle with rotation in image coordinates. + Attributes: x_center : The X coordinate of the top-left corner, in pixels. y_center : The Y coordinate of the top-left corner, in pixels. width: The width of the rectangle, in pixels. @@ -80,8 +81,11 @@ class Rect: @dataclasses.dataclass class NormalizedRect: """A rectangle with rotation in normalized coordinates. + The values of box + center location and size are within [0, 1]. + Attributes: x_center : The X normalized coordinate of the top-left corner. y_center : The Y normalized coordinate of the top-left corner. width: The width of the rectangle. From b81b5a90354b079fd564aa23d31008e43377a2d3 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 28 Oct 2022 01:38:15 -0700 Subject: [PATCH 06/34] Added a test for min_gesture_confidence --- .../test/vision/gesture_recognizer_test.py | 18 ++++++++++++++++++ .../vision/core/image_processing_options.py | 2 +- .../tasks/python/vision/gesture_recognizer.py | 8 ++++---- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 3bf994a1d..cbee18170 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -224,6 +224,24 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_min_gesture_confidence(self): + # Creates gesture recognizer. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _GestureRecognizerOptions(base_options=base_options, + min_gesture_confidence=2) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(self.test_image) + expected_result = _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) + # Only contains one top scoring gesture. + self.assertLen(recognition_result.gestures[0].classifications, 1) + # Actual gesture with top score matches expected gesture. 
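
As background for the min_gesture_confidence test above: the option is forwarded to the canned gesture classifier as a plain score threshold via ClassifierOptions (see the graph-options wiring in patch 2). A tiny illustration, with an assumed threshold value:

from mediapipe.tasks.python.components.processors import classifier_options

# min_gesture_confidence ends up here, as the classifier's score threshold.
options = classifier_options.ClassifierOptions(score_threshold=0.5)
options_proto = options.to_pb2()
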
+ actual_top_gesture = recognition_result.gestures[0].classifications[0] + expected_top_gesture = expected_result.gestures[0].classifications[0] + self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) + self.assertEqual(actual_top_gesture.label, expected_top_gesture.label) + def test_recognize_succeeds_with_num_hands(self): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) diff --git a/mediapipe/tasks/python/vision/core/image_processing_options.py b/mediapipe/tasks/python/vision/core/image_processing_options.py index 2a3a13088..1a519809c 100644 --- a/mediapipe/tasks/python/vision/core/image_processing_options.py +++ b/mediapipe/tasks/python/vision/core/image_processing_options.py @@ -30,7 +30,7 @@ class ImageProcessingOptions: Attributes: region_of_interest: The optional region-of-interest to crop from the image. If not specified, the full image is used. Coordinates must be in [0,1] - with 'left' < 'right' and 'top' < bottom. + with 'x_center' < 'width' and 'y_center' < height. rotation_degress: The rotation to apply to the image (or cropped region-of-interest), in degrees clockwise. The rotation must be a multiple (positive or negative) of 90°. diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 0036aa877..11cb5c7b2 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -118,10 +118,10 @@ class GestureRecognizerOptions: base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE num_hands: Optional[int] = 1 - min_hand_detection_confidence: Optional[int] = 0.5 - min_hand_presence_confidence: Optional[int] = 0.5 - min_tracking_confidence: Optional[int] = 0.5 - min_gesture_confidence: Optional[int] = -1 + min_hand_detection_confidence: Optional[float] = 0.5 + min_hand_presence_confidence: Optional[float] = 0.5 + min_tracking_confidence: Optional[float] = 0.5 + min_gesture_confidence: Optional[float] = -1 result_callback: Optional[ Callable[[GestureRecognitionResult, image_module.Image, int], None]] = None From f62cfd169005a37362a8d177a3e36372b14aa791 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sun, 30 Oct 2022 08:23:14 -0700 Subject: [PATCH 07/34] Removed classification proto to use the existing category dataclass instead and removed NormalizedLandmarkList and LandmarkList dataclasses --- .../tasks/python/components/containers/BUILD | 13 +- .../python/components/containers/category.py | 10 +- .../components/containers/classification.py | 121 ------------------ .../python/components/containers/landmark.py | 96 -------------- .../containers/landmark_detection_result.py | 61 ++++++--- mediapipe/tasks/python/test/vision/BUILD | 2 +- .../test/vision/gesture_recognizer_test.py | 49 ++++--- mediapipe/tasks/python/vision/BUILD | 2 +- .../tasks/python/vision/gesture_recognizer.py | 76 +++++++---- 9 files changed, 127 insertions(+), 303 deletions(-) delete mode 100644 mediapipe/tasks/python/components/containers/classification.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 85f5b61c6..de3b7352c 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -36,15 +36,6 @@ py_library( ], ) -py_library( - name = "classification", - srcs = ["classification.py"], - deps = [ - "//mediapipe/framework/formats:classification_py_pb2", - 
"//mediapipe/tasks/python/core:optional_dependencies", - ], -) - py_library( name = "landmark", srcs = ["landmark.py"], @@ -59,9 +50,11 @@ py_library( srcs = ["landmark_detection_result.py"], deps = [ ":rect", - ":classification", ":landmark", + "//mediapipe/framework/formats:classification_py_pb2", + "//mediapipe/framework/formats:landmark_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2", + "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index 0b347fc10..cfdb83740 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -14,7 +14,7 @@ """Category data class.""" import dataclasses -from typing import Any +from typing import Any, Optional from mediapipe.tasks.cc.components.containers.proto import category_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -39,10 +39,10 @@ class Category: category_name: The label of this category object. """ - index: int - score: float - display_name: str - category_name: str + index: Optional[int] = None + score: Optional[float] = None + display_name: Optional[str] = None + category_name: Optional[str] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _CategoryProto: diff --git a/mediapipe/tasks/python/components/containers/classification.py b/mediapipe/tasks/python/components/containers/classification.py deleted file mode 100644 index a9225e804..000000000 --- a/mediapipe/tasks/python/components/containers/classification.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classification data class.""" - -import dataclasses -from typing import Any, List, Optional - -from mediapipe.framework.formats import classification_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_ClassificationProto = classification_pb2.Classification -_ClassificationListProto = classification_pb2.ClassificationList - - -@dataclasses.dataclass -class Classification: - """A classification. - - Attributes: - index: The index of the class in the corresponding label map. - score: The probability score for this class. - label_name: Label or name of the class. - display_name: Optional human-readable string for display purposes. 
- """ - - index: Optional[int] = None - score: Optional[float] = None - label: Optional[str] = None - display_name: Optional[str] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ClassificationProto: - """Generates a Classification protobuf object.""" - return _ClassificationProto( - index=self.index, - score=self.score, - label=self.label, - display_name=self.display_name) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Classification': - """Creates a `Classification` object from the given protobuf object.""" - return Classification( - index=pb2_obj.index, - score=pb2_obj.score, - label=pb2_obj.label, - display_name=pb2_obj.display_name) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, Classification): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class ClassificationList: - """Represents the classifications for a given classifier. - Attributes: - classification : A list of `Classification` objects. - tensor_index: Optional index of the tensor that produced these - classifications. - tensor_name: Optional name of the tensor that produced these - classifications tensor metadata name. - """ - - classifications: List[Classification] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ClassificationListProto: - """Generates a ClassificationList protobuf object.""" - return _ClassificationListProto( - classification=[ - classification.to_pb2() - for classification in self.classifications - ]) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, - pb2_obj: _ClassificationListProto - ) -> 'ClassificationList': - """Creates a `ClassificationList` object from the given protobuf object.""" - return ClassificationList( - classifications=[ - Classification.create_from_pb2(classification) - for classification in pb2_obj.classification - ]) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. - """ - if not isinstance(other, ClassificationList): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index a86c17f2f..2c87ee676 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -20,9 +20,7 @@ from mediapipe.framework.formats import landmark_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _LandmarkProto = landmark_pb2.Landmark -_LandmarkListProto = landmark_pb2.LandmarkList _NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark -_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList @dataclasses.dataclass @@ -89,53 +87,6 @@ class Landmark: return self.to_pb2().__eq__(other.to_pb2()) -@dataclasses.dataclass -class LandmarkList: - """Represents the list of landmarks. - - Attributes: - landmarks : A list of `Landmark` objects. 
- """ - - landmarks: List[Landmark] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _LandmarkListProto: - """Generates a LandmarkList protobuf object.""" - return _LandmarkListProto( - landmark=[ - landmark.to_pb2() - for landmark in self.landmarks - ] - ) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, - pb2_obj: _LandmarkListProto - ) -> 'LandmarkList': - """Creates a `LandmarkList` object from the given protobuf object.""" - return LandmarkList( - landmarks=[ - Landmark.create_from_pb2(landmark) - for landmark in pb2_obj.landmark - ] - ) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. - """ - if not isinstance(other, LandmarkList): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - @dataclasses.dataclass class NormalizedLandmark: """A normalized version of above Landmark proto. @@ -201,50 +152,3 @@ class NormalizedLandmark: return False return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class NormalizedLandmarkList: - """Represents the list of normalized landmarks. - - Attributes: - landmarks : A list of `Landmark` objects. - """ - - landmarks: List[NormalizedLandmark] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _NormalizedLandmarkListProto: - """Generates a NormalizedLandmarkList protobuf object.""" - return _NormalizedLandmarkListProto( - landmark=[ - landmark.to_pb2() - for landmark in self.landmarks - ] - ) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, - pb2_obj: _NormalizedLandmarkListProto - ) -> 'NormalizedLandmarkList': - """Creates a `NormalizedLandmarkList` object from the given protobuf object.""" - return NormalizedLandmarkList( - landmarks=[ - NormalizedLandmark.create_from_pb2(landmark) - for landmark in pb2_obj.landmark - ] - ) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. 
- """ - if not isinstance(other, NormalizedLandmarkList): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py index 02ca5a918..ad21812c7 100644 --- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -14,19 +14,25 @@ """Landmarks Detection Result data class.""" import dataclasses -from typing import Any, Optional +from typing import Any, Optional, List from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 +from mediapipe.framework.formats import classification_pb2 +from mediapipe.framework.formats import landmark_pb2 from mediapipe.tasks.python.components.containers import rect as rect_module -from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult +_ClassificationProto = classification_pb2.Classification +_ClassificationListProto = classification_pb2.ClassificationList +_LandmarkListProto = landmark_pb2.LandmarkList +_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList _NormalizedRect = rect_module.NormalizedRect -_ClassificationList = classification_module.ClassificationList -_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList -_LandmarkList = landmark_module.LandmarkList +_Category = category_module.Category +_NormalizedLandmark = landmark_module.NormalizedLandmark +_Landmark = landmark_module.Landmark @dataclasses.dataclass @@ -34,25 +40,32 @@ class LandmarksDetectionResult: """Represents the landmarks detection result. Attributes: - landmarks : A `NormalizedLandmarkList` object. - classifications : A `ClassificationList` object. - world_landmarks : A `LandmarkList` object. + landmarks : A list of `NormalizedLandmark` objects. + categories : A list of `Category` objects. + world_landmarks : A list of `Landmark` objects. rect : A `NormalizedRect` object. 
""" - landmarks: Optional[_NormalizedLandmarkList] - classifications: Optional[_ClassificationList] - world_landmarks: Optional[_LandmarkList] + landmarks: Optional[List[_NormalizedLandmark]] + categories: Optional[List[_Category]] + world_landmarks: Optional[List[_Landmark]] rect: _NormalizedRect @doc_controls.do_not_generate_docs def to_pb2(self) -> _LandmarksDetectionResultProto: """Generates a LandmarksDetectionResult protobuf object.""" return _LandmarksDetectionResultProto( - landmarks=self.landmarks.to_pb2(), - classifications=self.classifications.to_pb2(), - world_landmarks=self.world_landmarks.to_pb2(), - rect=self.rect.to_pb2()) + landmarks=_NormalizedLandmarkListProto(landmarks=self.landmarks), + classifications=_ClassificationListProto( + classification=[ + _ClassificationProto( + index=category.index, + score=category.score, + label=category.category_name, + display_name=category.display_name) + for category in self.categories]), + world_landmarks=_LandmarkListProto(landmarks=self.world_landmarks), + rect=self.rect.to_pb2()) @classmethod @doc_controls.do_not_generate_docs @@ -63,11 +76,19 @@ class LandmarksDetectionResult: """Creates a `LandmarksDetectionResult` object from the given protobuf object.""" return LandmarksDetectionResult( - landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks), - classifications=_ClassificationList.create_from_pb2( - pb2_obj.classifications), - world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks), - rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) + landmarks=[ + _NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.landmarks.landmark], + categories=[category_module.Category( + score=classification.score, + index=classification.index, + category_name=classification.label, + display_name=classification.display_name) + for classification in pb2_obj.classifications.classification], + world_landmarks=[ + _Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.world_landmarks.landmark], + rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. 
diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 5ba91e6a9..9b0fab6cb 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -69,7 +69,7 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/containers:classification", + "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark_detection_result", "//mediapipe/tasks/python/core:base_options", diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index cbee18170..8f7c66519 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -24,7 +24,7 @@ from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 from mediapipe.tasks.python.components.containers import rect as rect_module -from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module from mediapipe.tasks.python.core import base_options as base_options_module @@ -36,12 +36,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image _LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult _BaseOptions = base_options_module.BaseOptions _Rect = rect_module.Rect -_Classification = classification_module.Classification -_ClassificationList = classification_module.ClassificationList +_Category = category_module.Category _Landmark = landmark_module.Landmark -_LandmarkList = landmark_module.LandmarkList _NormalizedLandmark = landmark_module.NormalizedLandmark -_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList _LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult _Image = image_module.Image _GestureRecognizer = gesture_recognizer.GestureRecognizer @@ -76,14 +73,11 @@ def _get_expected_gesture_recognition_result( text_format.Parse(f.read(), landmarks_detection_result_proto) landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( landmarks_detection_result_proto) - gesture = _ClassificationList( - classifications=[ - _Classification(label=gesture_label, index=gesture_index, - display_name='') - ]) + gesture = _Category(category_name=gesture_label, index=gesture_index, + display_name='') return _GestureRecognitionResult( - gestures=[gesture], - handedness=[landmarks_detection_result.classifications], + gestures=[[gesture]], + handedness=[landmarks_detection_result.categories], hand_landmarks=[landmarks_detection_result.landmarks], hand_world_landmarks=[landmarks_detection_result.world_landmarks]) @@ -115,25 +109,27 @@ class GestureRecognizerTest(parameterized.TestCase): self.assertLen(actual_result.handedness, len(expected_result.handedness)) 
self.assertLen(actual_result.gestures, len(expected_result.gestures)) # Actual landmarks match expected landmarks. - self.assertLen(actual_result.hand_landmarks[0].landmarks, - len(expected_result.hand_landmarks[0].landmarks)) - actual_landmarks = actual_result.hand_landmarks[0].landmarks - expected_landmarks = expected_result.hand_landmarks[0].landmarks + self.assertLen(actual_result.hand_landmarks[0], + len(expected_result.hand_landmarks[0])) + actual_landmarks = actual_result.hand_landmarks[0] + expected_landmarks = expected_result.hand_landmarks[0] for i in range(len(actual_landmarks)): self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x, delta=_LANDMARKS_ERROR_TOLERANCE) self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y, delta=_LANDMARKS_ERROR_TOLERANCE) # Actual handedness matches expected handedness. - actual_top_handedness = actual_result.handedness[0].classifications[0] - expected_top_handedness = expected_result.handedness[0].classifications[0] + actual_top_handedness = actual_result.handedness[0][0] + expected_top_handedness = expected_result.handedness[0][0] self.assertEqual(actual_top_handedness.index, expected_top_handedness.index) - self.assertEqual(actual_top_handedness.label, expected_top_handedness.label) + self.assertEqual(actual_top_handedness.category_name, + expected_top_handedness.category_name) # Actual gesture with top score matches expected gesture. - actual_top_gesture = actual_result.gestures[0].classifications[0] - expected_top_gesture = expected_result.gestures[0].classifications[0] + actual_top_gesture = actual_result.gestures[0][0] + expected_top_gesture = expected_result.gestures[0][0] self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) - self.assertEqual(actual_top_gesture.label, expected_top_gesture.label) + self.assertEqual(actual_top_gesture.category_name, + expected_top_gesture.category_name) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. @@ -235,12 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase): expected_result = _get_expected_gesture_recognition_result( _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) # Only contains one top scoring gesture. - self.assertLen(recognition_result.gestures[0].classifications, 1) + self.assertLen(recognition_result.gestures[0], 1) # Actual gesture with top score matches expected gesture. - actual_top_gesture = recognition_result.gestures[0].classifications[0] - expected_top_gesture = expected_result.gestures[0].classifications[0] + actual_top_gesture = recognition_result.gestures[0][0] + expected_top_gesture = expected_result.gestures[0][0] self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) - self.assertEqual(actual_top_gesture.label, expected_top_gesture.label) + self.assertEqual(actual_top_gesture.category_name, + expected_top_gesture.category_name) def test_recognize_succeeds_with_num_hands(self): # Creates gesture recognizer. 
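
The test updates above reflect the new result layout: gestures and handedness are nested lists of Category (one inner list per detected hand), and hand_landmarks/hand_world_landmarks are plain lists of landmarks. A small helper sketch for reading results in that shape (the function name is illustrative):

    from mediapipe.tasks.python.vision import gesture_recognizer

    def top_gestures(result: gesture_recognizer.GestureRecognitionResult):
      """Returns (handedness label, gesture label, score) for each detected hand."""
      return [
          (result.handedness[i][0].category_name,   # top handedness for hand i
           hand_gestures[0].category_name,          # top-scoring gesture for hand i
           hand_gestures[0].score)
          for i, hand_gestures in enumerate(result.gestures)
      ]
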
diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index fc74911ea..f0b3c9f53 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -74,7 +74,7 @@ py_library( "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2", - "//mediapipe/tasks/python/components/containers:classification", + "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 11cb5c7b2..142eb1dc6 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -27,7 +27,7 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2 -from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module @@ -80,10 +80,10 @@ class GestureRecognitionResult: hand_world_landmarks: Detected hand landmarks in world coordinates. 
""" - gestures: List[classification_module.ClassificationList] - handedness: List[classification_module.ClassificationList] - hand_landmarks: List[landmark_module.NormalizedLandmarkList] - hand_world_landmarks: List[landmark_module.LandmarkList] + gestures: List[List[category_module.Category]] + handedness: List[List[category_module.Category]] + hand_landmarks: List[List[landmark_module.NormalizedLandmark]] + hand_world_landmarks: List[List[landmark_module.Landmark]] @dataclasses.dataclass @@ -231,16 +231,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): gesture_recognition_result = GestureRecognitionResult( [ - classification_module.ClassificationList.create_from_pb2(gestures) - for gestures in gestures_proto_list + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in gesture_classifications.classification] + for gesture_classifications in gestures_proto_list ], [ - classification_module.ClassificationList.create_from_pb2(handedness) - for handedness in handedness_proto_list + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in handedness_classifications.classification] + for handedness_classifications in handedness_proto_list ], [ - landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks) + [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + for hand_landmark in hand_landmarks.landmark] for hand_landmarks in hand_landmarks_proto_list ], [ - landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks) + [landmark_module.Landmark.create_from_pb2(hand_world_landmark) + for hand_world_landmark in hand_world_landmarks.landmark] for hand_world_landmarks in hand_world_landmarks_proto_list ] ) @@ -314,16 +324,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): return GestureRecognitionResult( [ - classification_module.ClassificationList.create_from_pb2(gestures) - for gestures in gestures_proto_list + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in gesture_classifications.classification] + for gesture_classifications in gestures_proto_list ], [ - classification_module.ClassificationList.create_from_pb2(handedness) - for handedness in handedness_proto_list + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in handedness_classifications.classification] + for handedness_classifications in handedness_proto_list ], [ - landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks) + [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + for hand_landmark in hand_landmarks.landmark] for hand_landmarks in hand_landmarks_proto_list ], [ - landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks) + [landmark_module.Landmark.create_from_pb2(hand_world_landmark) + for hand_world_landmark in hand_world_landmarks.landmark] for hand_world_landmarks in hand_world_landmarks_proto_list ] ) @@ -377,16 +397,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): return GestureRecognitionResult( [ - classification_module.ClassificationList.create_from_pb2(gestures) - for gestures in gestures_proto_list + [ + category_module.Category( + index=gesture.index, score=gesture.score, + 
display_name=gesture.display_name, category_name=gesture.label) + for gesture in gesture_classifications.classification] + for gesture_classifications in gestures_proto_list ], [ - classification_module.ClassificationList.create_from_pb2(handedness) - for handedness in handedness_proto_list + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in handedness_classifications.classification] + for handedness_classifications in handedness_proto_list ], [ - landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks) + [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + for hand_landmark in hand_landmarks.landmark] for hand_landmarks in hand_landmarks_proto_list ], [ - landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks) + [landmark_module.Landmark.create_from_pb2(hand_world_landmark) + for hand_world_landmark in hand_world_landmarks.landmark] for hand_world_landmarks in hand_world_landmarks_proto_list ] ) From 4b66599419bd7ffe8bb90db7c7a5533a43004801 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sun, 30 Oct 2022 09:10:15 -0700 Subject: [PATCH 08/34] Updated docstring in gesture_recognizer --- mediapipe/tasks/python/vision/gesture_recognizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 142eb1dc6..e8d9ef342 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -444,7 +444,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): The `result_callback` provides: - The hand gesture recognition results. - - The input image that the image classifier runs on. + - The input image that the gesture recognizer runs on. - The input timestamp in milliseconds. Args: From fb4872b068b9d34d63997779f3b746d389852fa5 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sun, 30 Oct 2022 15:42:26 -0700 Subject: [PATCH 09/34] Refactored code and removed some issues --- .../python/components/containers/landmark.py | 40 +---- .../containers/landmark_detection_result.py | 14 +- .../tasks/python/vision/gesture_recognizer.py | 145 +++++------------- 3 files changed, 49 insertions(+), 150 deletions(-) diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index 2c87ee676..7eb7d8e96 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -30,9 +30,9 @@ class Landmark: Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points. Attributes: - x: The x coordinate of the 3D point. - y: The y coordinate of the 3D point. - z: The z coordinate of the 3D point. + x: The x coordinate. + y: The y coordinate. + z: The z coordinate. visibility: Landmark visibility. Should stay unset if not supported. Float score of whether landmark is visible or occluded by other objects. Landmark considered as invisible also if it is not present on the screen @@ -72,20 +72,6 @@ class Landmark: visibility=pb2_obj.visibility, presence=pb2_obj.presence) - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. 
- """ - if not isinstance(other, Landmark): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - @dataclasses.dataclass class NormalizedLandmark: @@ -94,9 +80,9 @@ class NormalizedLandmark: All coordinates should be within [0, 1]. Attributes: - x: The normalized x coordinate of the 3D point. - y: The normalized y coordinate of the 3D point. - z: The normalized z coordinate of the 3D point. + x: The normalized x coordinate. + y: The normalized y coordinate. + z: The normalized z coordinate. visibility: Landmark visibility. Should stay unset if not supported. Float score of whether landmark is visible or occluded by other objects. Landmark considered as invisible also if it is not present on the screen @@ -138,17 +124,3 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, NormalizedLandmark): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py index ad21812c7..7c21733e2 100644 --- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -14,7 +14,7 @@ """Landmarks Detection Result data class.""" import dataclasses -from typing import Any, Optional, List +from typing import Optional, List from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 from mediapipe.framework.formats import classification_pb2 @@ -89,15 +89,3 @@ class LandmarksDetectionResult: _Landmark.create_from_pb2(landmark) for landmark in pb2_obj.world_landmarks.landmark], rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. 
- """ - if not isinstance(other, LandmarksDetectionResult): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index e8d9ef342..c6d30dc4e 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -48,7 +48,6 @@ _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_TaskRunner = task_runner_module.TaskRunner _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' @@ -86,6 +85,45 @@ class GestureRecognitionResult: hand_world_landmarks: List[List[landmark_module.Landmark]] +def _build_recognition_result( + output_packets: Mapping[str, packet_module.Packet] +) -> GestureRecognitionResult: + gestures_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_GESTURE_STREAM_NAME]) + handedness_proto_list = packet_getter.get_proto_list( + output_packets[_HANDEDNESS_STREAM_NAME]) + hand_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_LANDMARKS_STREAM_NAME]) + hand_world_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + + return GestureRecognitionResult( + [ + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in gesture_classifications.classification] + for gesture_classifications in gestures_proto_list + ], [ + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in handedness_classifications.classification] + for handedness_classifications in handedness_proto_list + ], [ + [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + for hand_landmark in hand_landmarks.landmark] + for hand_landmarks in hand_landmarks_proto_list + ], [ + [landmark_module.Landmark.create_from_pb2(hand_world_landmark) + for hand_world_landmark in hand_world_landmarks.landmark] + for hand_world_landmarks in hand_world_landmarks_proto_list + ] + ) + + @dataclasses.dataclass class GestureRecognizerOptions: """Options for the gesture recognizer task. 
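
The new module-level _build_recognition_result helper consolidates the packet-to-dataclass conversion that was previously repeated in recognize(), recognize_for_video() and the async result callback; the remaining hunks of this patch replace those inline blocks with calls to it. Per hand, the conversion applied to each ClassificationList proto boils down to the sketch below (the standalone function name is illustrative):

    from mediapipe.framework.formats import classification_pb2
    from mediapipe.tasks.python.components.containers import category as category_module

    def classifications_to_categories(
        classifications: classification_pb2.ClassificationList):
      """Maps a ClassificationList proto to a list of Category dataclasses."""
      return [
          category_module.Category(
              index=c.index, score=c.score,
              display_name=c.display_name, category_name=c.label)
          for c in classifications.classification
      ]
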
@@ -220,40 +258,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) return - gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) - handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) - hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) - hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) - - gesture_recognition_result = GestureRecognitionResult( - [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in gesture_classifications.classification] - for gesture_classifications in gestures_proto_list - ], [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] - for handedness_classifications in handedness_proto_list - ], [ - [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - for hand_landmark in hand_landmarks.landmark] - for hand_landmarks in hand_landmarks_proto_list - ], [ - [landmark_module.Landmark.create_from_pb2(hand_world_landmark) - for hand_world_landmark in hand_world_landmarks.landmark] - for hand_world_landmarks in hand_world_landmarks_proto_list - ] - ) + gesture_recognition_result = _build_recognition_result(output_packets) timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp options.result_callback( gesture_recognition_result, image, @@ -313,40 +318,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): return GestureRecognitionResult([], [], [], []) - gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) - handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) - hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) - hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) - - return GestureRecognitionResult( - [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in gesture_classifications.classification] - for gesture_classifications in gestures_proto_list - ], [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] - for handedness_classifications in handedness_proto_list - ], [ - [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - for hand_landmark in hand_landmarks.landmark] - for hand_landmarks in hand_landmarks_proto_list - ], [ - [landmark_module.Landmark.create_from_pb2(hand_world_landmark) - for hand_world_landmark in hand_world_landmarks.landmark] - for hand_world_landmarks in hand_world_landmarks_proto_list - ] - ) + return _build_recognition_result(output_packets) def recognize_for_video( self, image: image_module.Image, @@ -386,40 +358,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): return 
GestureRecognitionResult([], [], [], []) - gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) - handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) - hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) - hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) - - return GestureRecognitionResult( - [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in gesture_classifications.classification] - for gesture_classifications in gestures_proto_list - ], [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] - for handedness_classifications in handedness_proto_list - ], [ - [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - for hand_landmark in hand_landmarks.landmark] - for hand_landmarks in hand_landmarks_proto_list - ], [ - [landmark_module.Landmark.create_from_pb2(hand_world_landmark) - for hand_world_landmark in hand_world_landmarks.landmark] - for hand_world_landmarks in hand_world_landmarks_proto_list - ] - ) + return _build_recognition_result(output_packets) def recognize_async( self, From 19be9e90123f1905a1d0fc1b7fbbe33b9047d0e6 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 05:34:31 -0700 Subject: [PATCH 10/34] Revised gesture recognizer implementation --- .../test/vision/gesture_recognizer_test.py | 55 +++++++++++++------ mediapipe/tasks/python/vision/BUILD | 3 - .../tasks/python/vision/gesture_recognizer.py | 48 ++++++---------- mediapipe/tasks/testdata/vision/BUILD | 1 + 4 files changed, 57 insertions(+), 50 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 8f7c66519..916bd3e0c 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -47,22 +47,26 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task' +_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = 'gesture_recognizer.task' +_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE = 'gesture_recognizer_with_custom_classifier.task' _NO_HANDS_IMAGE = 'cats_and_dogs.jpg' _TWO_HANDS_IMAGE = 'right_hands.jpg' +_FIST_IMAGE = 'fist.jpg' +_FIST_LANDMARKS = 'fist_landmarks.pbtxt' +_FIST_LABEL = 'Closed_Fist' _THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LABEL = 'Thumb_Up' -_THUMB_UP_INDEX = 5 _POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg' _POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' _POINTING_UP_LABEL = 'Pointing_Up' -_POINTING_UP_INDEX = 3 +_ROCK_LABEL = "Rock" _LANDMARKS_ERROR_TOLERANCE = 0.03 +_GESTURE_EXPECTED_INDEX = -1 def _get_expected_gesture_recognition_result( - file_path: str, gesture_label: str, gesture_index: int + file_path: str, gesture_label: str ) -> _GestureRecognitionResult: landmarks_detection_result_file_path = test_utils.get_test_data_path( file_path) @@ -73,7 
+77,8 @@ def _get_expected_gesture_recognition_result( text_format.Parse(f.read(), landmarks_detection_result_proto) landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( landmarks_detection_result_proto) - gesture = _Category(category_name=gesture_label, index=gesture_index, + gesture = _Category(category_name=gesture_label, + index=_GESTURE_EXPECTED_INDEX, display_name='') return _GestureRecognitionResult( gestures=[[gesture]], @@ -94,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase): self.test_image = _Image.create_from_file( test_utils.get_test_data_path(_THUMB_UP_IMAGE)) self.model_path = test_utils.get_test_data_path( - _GESTURE_RECOGNIZER_MODEL_FILE) + _GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) def _assert_actual_result_approximately_matches_expected_result( self, @@ -127,7 +132,7 @@ class GestureRecognizerTest(parameterized.TestCase): # Actual gesture with top score matches expected gesture. actual_top_gesture = actual_result.gestures[0][0] expected_top_gesture = expected_result.gestures[0][0] - self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) + self.assertEqual(actual_top_gesture.index, _GESTURE_EXPECTED_INDEX) self.assertEqual(actual_top_gesture.category_name, expected_top_gesture.category_name) @@ -163,10 +168,10 @@ class GestureRecognizerTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL )), (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL ))) def test_recognize(self, model_file_type, expected_recognition_result): # Creates gesture recognizer. @@ -194,10 +199,10 @@ class GestureRecognizerTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL )), (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL ))) def test_recognize_in_context(self, model_file_type, expected_recognition_result): @@ -224,12 +229,12 @@ class GestureRecognizerTest(parameterized.TestCase): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) options = _GestureRecognizerOptions(base_options=base_options, - min_gesture_confidence=2) + min_gesture_confidence=0.5) with _GestureRecognizer.create_from_options(options) as recognizer: # Performs hand gesture recognition on the input. recognition_result = recognizer.recognize(self.test_image) expected_result = _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL) # Only contains one top scoring gesture. self.assertLen(recognition_result.gestures[0], 1) # Actual gesture with top score matches expected gesture. @@ -266,11 +271,29 @@ class GestureRecognizerTest(parameterized.TestCase): recognition_result = recognizer.recognize(test_image, image_processing_options) expected_recognition_result = _get_expected_gesture_recognition_result( - _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX) + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) # Comparing results. 
self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_custom_gesture_fist(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the fist image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_FIST_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _FIST_LANDMARKS, _ROCK_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + def test_recognize_fails_with_region_of_interest(self): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) @@ -373,7 +396,7 @@ class GestureRecognizerTest(parameterized.TestCase): recognition_result = recognizer.recognize_for_video(self.test_image, timestamp) expected_recognition_result = _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL) self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) @@ -410,7 +433,7 @@ class GestureRecognizerTest(parameterized.TestCase): @parameterized.parameters( (_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)), + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)), (_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], []))) def test_recognize_async_calls(self, image_path, expected_result): test_image = _Image.create_from_file( diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 87de5b987..66c9ece65 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -87,12 +87,9 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/processors:classifier_options", diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index c6d30dc4e..9a2e3ba29 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -20,13 +20,9 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as 
image_module from mediapipe.python._framework_bindings import packet as packet_module -from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_classifier_graph_options_pb2 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2 -from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 -from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.processors import classifier_options @@ -38,12 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module _BaseOptions = base_options_module.BaseOptions -_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions _HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions -_HandDetectorGraphOptionsProto = hand_detector_graph_options_pb2.HandDetectorGraphOptions _HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions -_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -64,6 +57,7 @@ _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks' _HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +_GESTURE_DEFAULT_INDEX = -1 @dataclasses.dataclass @@ -72,8 +66,9 @@ class GestureRecognitionResult: element represents a single hand detected in the image. Attributes: - gestures: Recognized hand gestures with sorted order such that the - winning label is the first item in the list. + gestures: Recognized hand gestures of detected hands. Note that the index + of the gesture is always 0, because the raw indices from multiple gesture + classifiers cannot consolidate to a meaningful index. handedness: Classification of handedness. hand_landmarks: Detected hand landmarks in normalized image coordinates. hand_world_landmarks: Detected hand landmarks in world coordinates. 
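
As the revised docstring notes, per-gesture indices from the canned and custom classifiers cannot be consolidated, so the conversion in the next hunk pins them to the _GESTURE_DEFAULT_INDEX placeholder (-1, matching _GESTURE_EXPECTED_INDEX in the tests); callers should key off category_name and score rather than index. A minimal sketch (the label and threshold are example values):

    from mediapipe.tasks.python.vision import gesture_recognizer

    def has_thumbs_up(result: gesture_recognizer.GestureRecognitionResult,
                      min_score: float = 0.5) -> bool:
      """Checks the top-scoring gesture of each hand by label, not by index."""
      return any(
          hand[0].category_name == 'Thumb_Up' and hand[0].score >= min_score
          for hand in result.gestures)
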
@@ -101,16 +96,16 @@ def _build_recognition_result( [ [ category_module.Category( - index=gesture.index, score=gesture.score, + index=_GESTURE_DEFAULT_INDEX, score=gesture.score, display_name=gesture.display_name, category_name=gesture.label) for gesture in gesture_classifications.classification] for gesture_classifications in gestures_proto_list ], [ [ category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] + index=handedness.index, score=handedness.score, + display_name=handedness.display_name, category_name=handedness.label) + for handedness in handedness_classifications.classification] for handedness_classifications in handedness_proto_list ], [ [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) @@ -170,26 +165,17 @@ class GestureRecognizerOptions: base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - # Configure hand detector options. - hand_detector_options_proto = _HandDetectorGraphOptionsProto( - num_hands=self.num_hands, - min_detection_confidence=self.min_hand_detection_confidence) - - # Configure hand landmarker options. - hand_landmarks_detector_options_proto = _HandLandmarksDetectorGraphOptionsProto( - min_detection_confidence=self.min_hand_presence_confidence) - hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto( - hand_detector_graph_options=hand_detector_options_proto, - hand_landmarks_detector_graph_options=hand_landmarks_detector_options_proto, - min_tracking_confidence=self.min_tracking_confidence) + # Configure hand detector and hand landmarker options. + hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto() + hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence + hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands + hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence + hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence # Configure hand gesture recognizer options. 
- classifier_options = _ClassifierOptions( - score_threshold=self.min_gesture_confidence) - gesture_classifier_options = _GestureClassifierGraphOptionsProto( - classifier_options=classifier_options.to_pb2()) - hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto( - canned_gesture_classifier_graph_options=gesture_classifier_options) + hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() + hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence + hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence return _GestureRecognizerGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 0545c5cca..c7265f5c9 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -130,6 +130,7 @@ filegroup( "hand_landmark_lite.tflite", "hand_landmarker.task", "gesture_recognizer.task", + "gesture_recognizer_with_custom_classifier.task", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", From 888ddd4b74dd1965d386a2e6b34cf7ced99d4a3c Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 05:37:24 -0700 Subject: [PATCH 11/34] Removed unused classifier options proto --- mediapipe/tasks/python/vision/BUILD | 1 - mediapipe/tasks/python/vision/gesture_recognizer.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 66c9ece65..0505471e8 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -92,7 +92,6 @@ py_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 9a2e3ba29..82dc00f19 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -25,7 +25,6 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -37,7 +36,6 @@ _BaseOptions = base_options_module.BaseOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions _HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions 
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo From d635b4281e7b7defa9a722162e76e59cebc5e6c9 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 05:47:28 -0700 Subject: [PATCH 12/34] Added a test for the canned classification of the gesture victory --- .../test/vision/gesture_recognizer_test.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 916bd3e0c..9e1b47355 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -53,7 +53,9 @@ _NO_HANDS_IMAGE = 'cats_and_dogs.jpg' _TWO_HANDS_IMAGE = 'right_hands.jpg' _FIST_IMAGE = 'fist.jpg' _FIST_LANDMARKS = 'fist_landmarks.pbtxt' -_FIST_LABEL = 'Closed_Fist' +_VICTORY_IMAGE = 'victory.jpg' +_VICTORY_LANDMARKS = 'victory_landmarks.pbtxt' +_VICTORY_LABEL = 'Victory' _THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LABEL = 'Thumb_Up' @@ -276,6 +278,22 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_canned_gesture_victory(self): + # Creates gesture recognizer. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the fist image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_VICTORY_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _VICTORY_LANDMARKS, _VICTORY_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_custom_gesture_fist(self): # Creates gesture recognizer. model_path = test_utils.get_test_data_path( From 2b5a07757997cdf85916d0aa023325023418e9bc Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 05:48:45 -0700 Subject: [PATCH 13/34] Updated comments --- mediapipe/tasks/python/test/vision/gesture_recognizer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 9e1b47355..e8aa61883 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -283,7 +283,7 @@ class GestureRecognizerTest(parameterized.TestCase): base_options = _BaseOptions(model_asset_path=self.model_path) options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) with _GestureRecognizer.create_from_options(options) as recognizer: - # Load the fist image. + # Load the victory image. test_image = _Image.create_from_file( test_utils.get_test_data_path(_VICTORY_IMAGE)) # Performs hand gesture recognition on the input. 
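
The canned-gesture tests above (Thumb_Up, Victory, Pointing_Up) and the custom-classifier test (Rock) all follow the same image-mode flow. As an end-to-end usage sketch mirroring them (model and image paths are placeholders for local copies of the test assets):

    from mediapipe.python._framework_bindings import image as image_module
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.vision import gesture_recognizer

    options = gesture_recognizer.GestureRecognizerOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='gesture_recognizer.task'),  # placeholder bundle path
        num_hands=1)
    with gesture_recognizer.GestureRecognizer.create_from_options(options) as recognizer:
      image = image_module.Image.create_from_file('victory.jpg')  # placeholder image path
      result = recognizer.recognize(image)
      if result.gestures:
        print(result.gestures[0][0].category_name)  # e.g. 'Victory'
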
From d3b472e888ae7b62b7dd921949b3e9db71c37303 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 22:16:37 -0700 Subject: [PATCH 14/34] Add allow_list/deny_list support --- mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/gesture_recognizer_test.py | 111 ++++++++++++++++-- mediapipe/tasks/python/vision/BUILD | 3 +- .../tasks/python/vision/gesture_recognizer.py | 41 ++++--- 4 files changed, 127 insertions(+), 29 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 40afe22b8..da8ad3f83 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -88,6 +88,7 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark_detection_result", + "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:gesture_recognizer", diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index e8aa61883..d5cd72cd7 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -27,6 +27,7 @@ from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module +from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import gesture_recognizer @@ -40,6 +41,7 @@ _Category = category_module.Category _Landmark = landmark_module.Landmark _NormalizedLandmark = landmark_module.NormalizedLandmark _LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult +_ClassifierOptions = classifier_options.ClassifierOptions _Image = image_module.Image _GestureRecognizer = gesture_recognizer.GestureRecognizer _GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions @@ -59,10 +61,12 @@ _VICTORY_LABEL = 'Victory' _THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LABEL = 'Thumb_Up' +_POINTING_UP_IMAGE = 'pointing_up.jpg' +_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt' _POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg' -_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' +_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' _POINTING_UP_LABEL = 'Pointing_Up' -_ROCK_LABEL = "Rock" +_ROCK_LABEL = 'Rock' _LANDMARKS_ERROR_TOLERANCE = 0.03 _GESTURE_EXPECTED_INDEX = -1 @@ -227,11 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) - def test_recognize_succeeds_with_min_gesture_confidence(self): + def test_recognize_succeeds_with_score_threshold(self): # Creates gesture recognizer. 
base_options = _BaseOptions(model_asset_path=self.model_path) - options = _GestureRecognizerOptions(base_options=base_options, - min_gesture_confidence=0.5) + canned_gesture_classifier_options = _ClassifierOptions(score_threshold=.5) + options = _GestureRecognizerOptions( + base_options=base_options, + canned_gesture_classifier_options=canned_gesture_classifier_options) with _GestureRecognizer.create_from_options(options) as recognizer: # Performs hand gesture recognition on the input. recognition_result = recognizer.recognize(self.test_image) @@ -273,7 +279,7 @@ class GestureRecognizerTest(parameterized.TestCase): recognition_result = recognizer.recognize(test_image, image_processing_options) expected_recognition_result = _get_expected_gesture_recognition_result( - _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) + _POINTING_UP_ROTATED_LANDMARKS, _POINTING_UP_LABEL) # Comparing results. self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) @@ -294,14 +300,14 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) - def test_recognize_succeeds_with_custom_gesture_fist(self): + def test_recognize_succeeds_with_custom_gesture_rock(self): # Creates gesture recognizer. model_path = test_utils.get_test_data_path( _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) base_options = _BaseOptions(model_asset_path=model_path) options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) with _GestureRecognizer.create_from_options(options) as recognizer: - # Load the fist image. + # Load the rock image. test_image = _Image.create_from_file( test_utils.get_test_data_path(_FIST_IMAGE)) # Performs hand gesture recognition on the input. @@ -312,6 +318,95 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_allow_gesture_pointing_up(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + category_allowlist=['Pointing_Up']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + + def test_recognize_succeeds_with_deny_gesture_pointing_up(self): + # Creates gesture recognizer. 
+ model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + category_denylist=['Pointing_Up']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + actual_top_gesture = recognition_result.gestures[0][0] + self.assertEqual(actual_top_gesture.category_name, 'None') + + def test_recognize_succeeds_with_allow_all_gestures_except_pointing_up(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + score_threshold=.5, category_allowlist=[ + 'None', 'Open_Palm', 'Victory', 'Thumb_Down', 'Thumb_Up', + 'ILoveYou', 'Closed_Fist']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + actual_top_gesture = recognition_result.gestures[0][0] + self.assertEqual(actual_top_gesture.category_name, 'None') + + def test_recognize_succeeds_with_prefer_allow_list_than_deny_list(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + score_threshold=.5, category_allowlist=['Pointing_Up'], + category_denylist=['Pointing_Up']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + def test_recognize_fails_with_region_of_interest(self): # Creates gesture recognizer. 
base_options = _BaseOptions(model_asset_path=self.model_path) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 0505471e8..dec149908 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -88,10 +88,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", + "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 82dc00f19..2659f9a03 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -21,10 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2 -from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2 -from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,8 +33,7 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image _BaseOptions = base_options_module.BaseOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions -_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions -_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions +_ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -137,11 +135,16 @@ class GestureRecognizerOptions: score in the hand landmark detection. min_tracking_confidence: The minimum confidence score for the hand tracking to be considered successful. - min_gesture_confidence: The minimum confidence score for the gestures to be - considered successful. If < 0, the gesture confidence thresholds in the - model metadata are used. - TODO: Note this option is subject to change, after scoring merging - calculator is implemented. 
+ canned_gesture_classifier_options: Options for configuring the canned + gestures classifier, such as score threshold, allow list and deny list of + gestures. The categories for canned gesture classifiers are: + ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down", + "Thumb_Up", "Victory", "ILoveYou"] + TODO :Note this option is subject to change. + custom_gesture_classifier_options: Options for configuring the custom + gestures classifier, such as score threshold, allow list and deny list of + gestures. + TODO :Note this option is subject to change. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. @@ -152,7 +155,8 @@ class GestureRecognizerOptions: min_hand_detection_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 - min_gesture_confidence: Optional[float] = -1 + canned_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions() + custom_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions() result_callback: Optional[ Callable[[GestureRecognitionResult, image_module.Image, int], None]] = None @@ -163,23 +167,22 @@ class GestureRecognizerOptions: base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + # Initialize gesture recognizer options from base options. + gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto( + base_options=base_options_proto) # Configure hand detector and hand landmarker options. - hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto() + hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence # Configure hand gesture recognizer options. 
- hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() - hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence - hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence + hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options + hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(self.canned_gesture_classifier_options.to_pb2()) + hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(self.custom_gesture_classifier_options.to_pb2()) - return _GestureRecognizerGraphOptionsProto( - base_options=base_options_proto, - hand_landmarker_graph_options=hand_landmarker_options_proto, - hand_gesture_recognizer_graph_options=hand_gesture_recognizer_options_proto - ) + return gesture_recognizer_options_proto class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): From a913255080b10692808a9a66edd10f1e490758af Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 23:07:05 -0700 Subject: [PATCH 15/34] Removed min score thres from tests --- mediapipe/tasks/python/test/vision/gesture_recognizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index d5cd72cd7..fb8ca6713 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -367,7 +367,7 @@ class GestureRecognizerTest(parameterized.TestCase): _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) base_options = _BaseOptions(model_asset_path=model_path) canned_gesture_classifier_options = _ClassifierOptions( - score_threshold=.5, category_allowlist=[ + category_allowlist=[ 'None', 'Open_Palm', 'Victory', 'Thumb_Down', 'Thumb_Up', 'ILoveYou', 'Closed_Fist']) options = _GestureRecognizerOptions( @@ -389,7 +389,7 @@ class GestureRecognizerTest(parameterized.TestCase): _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) base_options = _BaseOptions(model_asset_path=model_path) canned_gesture_classifier_options = _ClassifierOptions( - score_threshold=.5, category_allowlist=['Pointing_Up'], + category_allowlist=['Pointing_Up'], category_denylist=['Pointing_Up']) options = _GestureRecognizerOptions( base_options=base_options, From c5765ac8363b17557d5820c7c6f9f6942cde2492 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 1 Nov 2022 15:37:00 -0700 Subject: [PATCH 16/34] Refactored Rect to use top-left coordinates and appropriately updated the Image Classifier and Gesture Recognizer APIs/tests --- .../python/components/containers/rect.py | 73 ++++++------------- mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/gesture_recognizer_test.py | 2 +- .../test/vision/image_classifier_test.py | 30 ++++---- mediapipe/tasks/python/vision/BUILD | 2 + .../vision/core/base_vision_task_api.py | 14 ++-- .../vision/core/image_processing_options.py | 2 +- .../tasks/python/vision/gesture_recognizer.py | 2 +- .../tasks/python/vision/image_classifier.py | 49 ++++++------- 9 files changed, 75 insertions(+), 100 deletions(-) diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py 
index 510561592..90e98fef4 100644 --- a/mediapipe/tasks/python/components/containers/rect.py +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -19,75 +19,44 @@ from typing import Any, Optional from mediapipe.framework.formats import rect_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls -_RectProto = rect_pb2.Rect _NormalizedRectProto = rect_pb2.NormalizedRect @dataclasses.dataclass class Rect: - """A rectangle with rotation in image coordinates. + """A rectangle, used e.g. as part of detection results or as input + region-of-interest. - Attributes: x_center : The X coordinate of the top-left corner, in pixels. - y_center : The Y coordinate of the top-left corner, in pixels. - width: The width of the rectangle, in pixels. - height: The height of the rectangle, in pixels. - rotation: Rotation angle is clockwise in radians. - rect_id: Optional unique id to help associate different rectangles to each - other. + The coordinates are normalized wrt the image dimensions, i.e. generally in + [0,1] but they may exceed these bounds if describing a region overlapping the + image. The origin is on the top-left corner of the image. + + Attributes: + left: The X coordinate of the left side of the rectangle. + top: The Y coordinate of the top of the rectangle. + right: The X coordinate of the right side of the rectangle. + bottom: The Y coordinate of the bottom of the rectangle. """ - x_center: int - y_center: int - width: int - height: int - rotation: Optional[float] = 0.0 - rect_id: Optional[int] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _RectProto: - """Generates a Rect protobuf object.""" - return _RectProto( - x_center=self.x_center, - y_center=self.y_center, - width=self.width, - height=self.height, - ) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect': - """Creates a `Rect` object from the given protobuf object.""" - return Rect( - x_center=pb2_obj.x_center, - y_center=pb2_obj.y_center, - width=pb2_obj.width, - height=pb2_obj.height) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, Rect): - return False - - return self.to_pb2().__eq__(other.to_pb2()) + left: float + top: float + right: float + bottom: float @dataclasses.dataclass class NormalizedRect: - """A rectangle with rotation in normalized coordinates. + """A rectangle with rotation in normalized coordinates. Location of the center + of the rectangle in image coordinates. The (0.0, 0.0) point is at the + (top, left) corner. The values of box center location and size are within [0, 1]. - Attributes: x_center : The X normalized coordinate of the top-left corner. - y_center : The Y normalized coordinate of the top-left corner. + Attributes: x_center: The normalized X coordinate of the rectangle, in + image coordinates. + y_center: The normalized Y coordinate of the rectangle, in image coordinates. width: The width of the rectangle. height: The height of the rectangle. rotation: Rotation angle is clockwise in radians. 
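To make the coordinate change above concrete, here is a small sketch (not part of the patch) of how the refactored top-left style Rect relates to the center/size NormalizedRect that the tasks use internally; the arithmetic mirrors the conversion added to convert_to_normalized_rect in base_vision_task_api.py later in this patch, and the example values are the soccer-ball region of interest used in the image classifier tests.

from mediapipe.tasks.python.components.containers import rect

# Region of interest expressed with the refactored Rect (normalized [0, 1] coordinates).
roi = rect.Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)

# Equivalent center/size representation, as computed by convert_to_normalized_rect.
x_center = (roi.left + roi.right) / 2.0   # ~0.532
y_center = (roi.top + roi.bottom) / 2.0   # ~0.521
width = roi.right - roi.left              # ~0.164
height = roi.bottom - roi.top             # ~0.427
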
diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index da8ad3f83..4966ffd29 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -54,6 +54,7 @@ py_test( "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "//mediapipe/tasks/python/vision/core:image_processing_options", ], ) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index fb8ca6713..e2fbcbcd5 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -78,7 +78,7 @@ def _get_expected_gesture_recognition_result( file_path) with open(landmarks_detection_result_file_path, "rb") as f: landmarks_detection_result_proto = _LandmarksDetectionResultProto() - # # Use this if a .pb file is available. + # Use this if a .pb file is available. # landmarks_detection_result_proto.ParseFromString(f.read()) text_format.Parse(f.read(), landmarks_detection_result_proto) landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index afaf921a7..e56bcdea0 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -29,8 +29,10 @@ from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module -_NormalizedRect = rect.NormalizedRect + +_Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions _ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category @@ -41,6 +43,7 @@ _Image = image.Image _ImageClassifier = image_classifier.ImageClassifier _ImageClassifierOptions = image_classifier.ImageClassifierOptions _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' @@ -226,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase): # Load the test image. test_image = _Image.create_from_file( test_utils.get_test_data_path('multi_objects.jpg')) - # NormalizedRect around the soccer ball. - roi = _NormalizedRect( - x_center=0.532, y_center=0.521, width=0.164, height=0.427) + # Region-of-interest around the soccer ball. + roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345) + image_processing_options = _ImageProcessingOptions(roi) # Performs image classification on the input. - image_result = classifier.classify(test_image, roi) + image_result = classifier.classify(test_image, image_processing_options) # Comparing results. _assert_proto_equals(image_result.to_pb2(), _generate_soccer_ball_results(0).to_pb2()) @@ -414,12 +417,12 @@ class ImageClassifierTest(parameterized.TestCase): # Load the test image. test_image = _Image.create_from_file( test_utils.get_test_data_path('multi_objects.jpg')) - # NormalizedRect around the soccer ball. 
- roi = _NormalizedRect( - x_center=0.532, y_center=0.521, width=0.164, height=0.427) + # Region-of-interest around the soccer ball. + roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345) + image_processing_options = _ImageProcessingOptions(roi) for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( - test_image, timestamp, roi) + test_image, timestamp, image_processing_options) self.assertEqual(classification_result, _generate_soccer_ball_results(timestamp)) @@ -486,9 +489,9 @@ class ImageClassifierTest(parameterized.TestCase): # Load the test image. test_image = _Image.create_from_file( test_utils.get_test_data_path('multi_objects.jpg')) - # NormalizedRect around the soccer ball. - roi = _NormalizedRect( - x_center=0.532, y_center=0.521, width=0.164, height=0.427) + # Region-of-interest around the soccer ball. + roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345) + image_processing_options = _ImageProcessingOptions(roi) observed_timestamp_ms = -1 def check_result(result: _ClassificationResult, output_image: _Image, @@ -508,7 +511,8 @@ class ImageClassifierTest(parameterized.TestCase): result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): - classifier.classify_async(test_image, timestamp, roi) + classifier.classify_async(test_image, timestamp, + image_processing_options) if __name__ == '__main__': diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index dec149908..2b9b5201e 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -56,6 +56,7 @@ py_library( "//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/vision/core:base_vision_task_api", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "//mediapipe/tasks/python/vision/core:image_processing_options", ], ) @@ -96,5 +97,6 @@ py_library( "//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/vision/core:base_vision_task_api", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "//mediapipe/tasks/python/vision/core:image_processing_options", ], ) diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index be290c83c..86771adee 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -160,15 +160,15 @@ class BaseVisionTaskApi(object): if not roi_allowed: raise ValueError("This task doesn't support region-of-interest.") roi = options.region_of_interest - if roi.x_center >= roi.width or roi.y_center >= roi.height: + if roi.left >= roi.right or roi.top >= roi.bottom: raise ValueError( - "Expected Rect with x_center < width and y_center < height.") - if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1: + "Expected Rect with left < right and top < bottom.") + if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1: raise ValueError("Expected Rect values to be in [0,1].") - normalized_rect.x_center = roi.x_center + roi.width / 2.0 - normalized_rect.y_center = roi.y_center + roi.height / 2.0 - normalized_rect.width = roi.width - roi.x_center - normalized_rect.height = roi.height - roi.y_center + normalized_rect.x_center = (roi.left + roi.right) / 2.0 + normalized_rect.y_center = (roi.top + roi.bottom) / 2.0 + normalized_rect.width = roi.right - roi.left + normalized_rect.height = roi.bottom 
- roi.top return normalized_rect def close(self) -> None: diff --git a/mediapipe/tasks/python/vision/core/image_processing_options.py b/mediapipe/tasks/python/vision/core/image_processing_options.py index 1a519809c..fafde049e 100644 --- a/mediapipe/tasks/python/vision/core/image_processing_options.py +++ b/mediapipe/tasks/python/vision/core/image_processing_options.py @@ -30,7 +30,7 @@ class ImageProcessingOptions: Attributes: region_of_interest: The optional region-of-interest to crop from the image. If not specified, the full image is used. Coordinates must be in [0,1] - with 'x_center' < 'width' and 'y_center' < height. + with 'left' < 'right' and 'top' < 'bottom'. rotation_degress: The rotation to apply to the image (or cropped region-of-interest), in degrees clockwise. The rotation must be a multiple (positive or negative) of 90°. diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 2659f9a03..33286b90b 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -63,7 +63,7 @@ class GestureRecognitionResult: Attributes: gestures: Recognized hand gestures of detected hands. Note that the index - of the gesture is always 0, because the raw indices from multiple gesture + of the gesture is always -1, because the raw indices from multiple gesture classifiers cannot consolidate to a meaningful index. handedness: Classification of handedness. hand_landmarks: Detected hand landmarks in normalized image coordinates. diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 7be5d5733..89e6775e2 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -31,12 +31,14 @@ from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' @@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' -_NORM_RECT_NAME = 'norm_rect_in' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' _NORM_RECT_TAG = 'NORM_RECT' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 -def _build_full_image_norm_rect() -> _NormalizedRect: - # Builds a NormalizedRect covering the entire image. - return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) - - @dataclasses.dataclass class ImageClassifierOptions: """Options for the image classifier task. 
@@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), - ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], output_streams=[ ':'.join([ @@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): _RunningMode.LIVE_STREAM), options.running_mode, packets_callback if options.result_callback else None) - # TODO: Replace _NormalizedRect with ImageProcessingOption def classify( self, image: image_module.Image, - roi: Optional[_NormalizedRect] = None + image_processing_options: Optional[_ImageProcessingOptions] = None ) -> classifications.ClassificationResult: """Performs image classification on the provided MediaPipe Image. Args: image: MediaPipe Image. - roi: The region of interest. + image_processing_options: Options for image processing. Returns: A classification result object that contains a list of classifications. @@ -190,10 +186,11 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. """ - norm_rect = roi if roi is not None else _build_full_image_norm_rect() + normalized_rect = self.convert_to_normalized_rect(image_processing_options) output_packets = self._process_image_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), - _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()) + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2()) }) classification_result_proto = classifications_pb2.ClassificationResult() @@ -210,7 +207,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - roi: Optional[_NormalizedRect] = None + image_processing_options: Optional[_ImageProcessingOptions] = None ) -> classifications.ClassificationResult: """Performs image classification on the provided video frames. @@ -222,7 +219,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. - roi: The region of interest. + image_processing_options: Options for image processing. Returns: A classification result object that contains a list of classifications. @@ -231,13 +228,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. 
""" - norm_rect = roi if roi is not None else _build_full_image_norm_rect() + normalized_rect = self.convert_to_normalized_rect(image_processing_options) output_packets = self._process_video_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_NAME: - packet_creator.create_proto(norm_rect.to_pb2()).at( + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) @@ -251,10 +248,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): for classification in classification_result_proto.classifications ]) - def classify_async(self, - image: image_module.Image, - timestamp_ms: int, - roi: Optional[_NormalizedRect] = None) -> None: + def classify_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None + ) -> None: """Sends live image data (an Image with a unique timestamp) to perform image classification. Only use this method when the ImageClassifier is created with the live @@ -275,18 +274,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. - roi: The region of interest. + image_processing_options: Options for image processing. Raises: ValueError: If the current input timestamp is smaller than what the image classifier has already processed. """ - norm_rect = roi if roi is not None else _build_full_image_norm_rect() + normalized_rect = self.convert_to_normalized_rect(image_processing_options) self._send_live_stream_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_NAME: - packet_creator.create_proto(norm_rect.to_pb2()).at( + _NORM_RECT_STREAM_NAME: + packet_creator.create_proto(normalized_rect.to_pb2()).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) From 35a04522faecd3e82acedeeaaf747da2615820c4 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 3 Nov 2022 00:46:00 -0700 Subject: [PATCH 17/34] Moved the OutputType and Activation classes to ImageSegmenter's inner classes --- .../test/vision/image_segmenter_test.py | 4 ++-- .../tasks/python/vision/image_segmenter.py | 23 +++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index dde751c1d..5072d3482 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -33,8 +33,8 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat -_OutputType = image_segmenter.OutputType -_Activation = image_segmenter.Activation +_OutputType = image_segmenter.ImageSegmenterOptions.OutputType +_Activation = image_segmenter.ImageSegmenterOptions.Activation _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index ebfeca1bc..c1b50a5ae 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ 
-44,18 +44,6 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 -class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - - -class Activation(enum.Enum): - NONE = 0 - SIGMOID = 1 - SOFTMAX = 2 - - @dataclasses.dataclass class ImageSegmenterOptions: """Options for the image segmenter task. @@ -74,6 +62,17 @@ class ImageSegmenterOptions: data. The result callback should only be specified when the running mode is set to the live stream mode. """ + + class OutputType(enum.Enum): + UNSPECIFIED = 0 + CATEGORY_MASK = 1 + CONFIDENCE_MASK = 2 + + class Activation(enum.Enum): + NONE = 0 + SIGMOID = 1 + SOFTMAX = 2 + base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE output_type: Optional[OutputType] = OutputType.CATEGORY_MASK From b472d8ff6682658813dc7a01670fbe7694f301c7 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 3 Nov 2022 09:22:35 -0700 Subject: [PATCH 18/34] Add the missing __init__.py file. PiperOrigin-RevId: 485892501 --- .../tasks/python/components/processors/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 mediapipe/tasks/python/components/processors/__init__.py diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
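As a usage note on the ImageSegmenter refactor above (a hedged sketch, not part of the patches): with OutputType and Activation now nested inside ImageSegmenterOptions, callers reference the enums through the options class rather than the module. The model path below is a placeholder.

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_segmenter

options = image_segmenter.ImageSegmenterOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='segmentation_model.tflite'),  # illustrative path
    output_type=image_segmenter.ImageSegmenterOptions.OutputType.CATEGORY_MASK)
# The activation enum is scoped the same way, e.g.:
# image_segmenter.ImageSegmenterOptions.Activation.SOFTMAX
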
From d29c3d7512151af026b638ee3d76e097ad268b67 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 3 Nov 2022 11:28:56 -0700 Subject: [PATCH 19/34] Add metadata writer to image_classifier in model_maker PiperOrigin-RevId: 485926985 --- .../python/core/tasks/custom_model.py | 9 +-- .../python/core/utils/model_util.py | 46 +++++++------ .../python/core/utils/model_util_test.py | 26 ++++--- .../python/core/utils/test_util.py | 46 ++++++++++--- .../python/vision/image_classifier/BUILD | 8 ++- .../image_classifier/image_classifier.py | 34 ++++++++-- .../image_classifier/image_classifier_test.py | 17 ++++- .../vision/image_classifier/testdata/BUILD | 23 +++++++ .../image_classifier/testdata/metadata.json | 68 +++++++++++++++++++ mediapipe/tasks/python/test/BUILD | 2 +- 10 files changed, 225 insertions(+), 54 deletions(-) create mode 100644 mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD create mode 100644 mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json diff --git a/mediapipe/model_maker/python/core/tasks/custom_model.py b/mediapipe/model_maker/python/core/tasks/custom_model.py index 66d1494db..188bf62cc 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model.py @@ -52,6 +52,7 @@ class CustomModel(abc.ABC): """Prints a summary of the model.""" self._model.summary() +# TODO: Remove this method when all tasks use Metadata writer def export_tflite( self, export_dir: str, @@ -62,7 +63,7 @@ class CustomModel(abc.ABC): Args: export_dir: The directory to save exported files. - tflite_filename: File name to save tflite model. The full export path is + tflite_filename: File name to save TFLite model. The full export path is {export_dir}/{tflite_filename}. quantization_config: The configuration for model quantization. preprocess: A callable to preprocess the representative dataset for @@ -73,11 +74,11 @@ class CustomModel(abc.ABC): tf.io.gfile.makedirs(export_dir) tflite_filepath = os.path.join(export_dir, tflite_filename) - # TODO: Populate metadata to the exported TFLite model. - model_util.export_tflite( + tflite_model = model_util.convert_to_tflite( model=self._model, - tflite_filepath=tflite_filepath, quantization_config=quantization_config, preprocess=preprocess) + model_util.save_tflite( + tflite_model=tflite_model, tflite_file=tflite_filepath) tf.compat.v1.logging.info( 'TensorFlow Lite model exported successfully: %s' % tflite_filepath) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index e1228eb6e..02d4f5b1e 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, return len(train_data) // batch_size -def export_tflite( +def convert_to_tflite( model: tf.keras.Model, - tflite_filepath: str, quantization_config: Optional[quantization.QuantizationConfig] = None, supported_ops: Tuple[tf.lite.OpsSet, ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), - preprocess: Optional[Callable[..., bool]] = None): - """Converts the model to tflite format and saves it. + preprocess: Optional[Callable[..., bool]] = None) -> bytearray: + """Converts the input Keras model to TFLite format. Args: - model: model to be converted to tflite. - tflite_filepath: File path to save tflite model. + model: Keras model to be converted to TFLite. 
quantization_config: Configuration for post-training quantization. supported_ops: A list of supported ops in the converted TFLite file. preprocess: A callable to preprocess the representative dataset for quantization. The callable takes three arguments in order: feature, label, and is_training. - """ - if tflite_filepath is None: - raise ValueError( - "TFLite filepath couldn't be None when exporting to tflite.") + Returns: + bytearray of TFLite model + """ with tempfile.TemporaryDirectory() as temp_dir: save_path = os.path.join(temp_dir, 'saved_model') model.save(save_path, include_optimizer=False, save_format='tf') @@ -122,9 +119,22 @@ def export_tflite( converter.target_spec.supported_ops = supported_ops tflite_model = converter.convert() + return tflite_model - with tf.io.gfile.GFile(tflite_filepath, 'wb') as f: + +def save_tflite(tflite_model: bytearray, tflite_file: str) -> None: + """Saves TFLite file to tflite_file. + + Args: + tflite_model: A valid flatbuffer representing the TFLite model. + tflite_file: File path to save TFLite model. + """ + if tflite_file is None: + raise ValueError("TFLite filepath can't be None when exporting to TFLite.") + with tf.io.gfile.GFile(tflite_file, 'wb') as f: f.write(tflite_model) + tf.compat.v1.logging.info( + 'TensorFlow Lite model exported successfully to: %s' % tflite_file) class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): @@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): class LiteRunner(object): """A runner to do inference with the TFLite model.""" - def __init__(self, tflite_filepath: str): - """Initializes Lite runner with tflite model file. + def __init__(self, tflite_model: bytearray): + """Initializes Lite runner from TFLite model buffer. Args: - tflite_filepath: File path to the TFLite model. + tflite_model: A valid flatbuffer representing the TFLite model. 
""" - with tf.io.gfile.GFile(tflite_filepath, 'rb') as f: - tflite_model = f.read() self.interpreter = tf.lite.Interpreter(model_content=tflite_model) self.interpreter.allocate_tensors() self.input_details = self.interpreter.get_input_details() @@ -250,9 +258,9 @@ class LiteRunner(object): return output_tensors -def get_lite_runner(tflite_filepath: str) -> 'LiteRunner': - """Returns a `LiteRunner` from file path to TFLite model.""" - lite_runner = LiteRunner(tflite_filepath) +def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner': + """Returns a `LiteRunner` from flatbuffer of the TFLite model.""" + lite_runner = LiteRunner(tflite_buffer) return lite_runner diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 35b52eb75..1f9e0f1db 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): 'name': 'test' }) - def test_export_tflite(self): + def test_convert_to_tflite(self): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) - tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') - model_util.export_tflite(model, tflite_file) + tflite_model = model_util.convert_to_tflite(model) test_util.test_tflite( - keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + keras_model=model, tflite_model=tflite_model, size=[1, input_dim]) @parameterized.named_parameters( dict( @@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): testcase_name='float16_quantize', config=quantization.QuantizationConfig.for_float16(), model_size=1468)) - def test_export_tflite_quantized(self, config, model_size): + def test_convert_to_tflite_quantized(self, config, model_size): input_dim = 16 num_classes = 2 max_input_value = 5 model = test_util.build_model( input_shape=[input_dim], num_classes=num_classes) - tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite') - model_util.export_tflite( - model=model, tflite_filepath=tflite_file, quantization_config=config) + tflite_model = model_util.convert_to_tflite( + model=model, quantization_config=config) self.assertTrue( test_util.test_tflite( keras_model=model, - tflite_file=tflite_file, + tflite_model=tflite_model, size=[1, input_dim], high=max_input_value, atol=1e-00)) - self.assertNear(os.path.getsize(tflite_file), model_size, 300) + self.assertNear(len(tflite_model), model_size, 300) + def test_save_tflite(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + tflite_model = model_util.convert_to_tflite(model) + tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') + model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) + test_util.test_tflite_file( + keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) if __name__ == '__main__': tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py index b402d3793..14d02814e 100644 --- a/mediapipe/model_maker/python/core/utils/test_util.py +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model: return model -def is_same_output(tflite_file: str, +def is_same_output(tflite_model: bytearray, keras_model: tf.keras.Model, 
input_tensors: Union[List[tf.Tensor], tf.Tensor], atol: float = 1e-04) -> bool: """Returns if the output of TFLite model and keras model are identical.""" # Gets output from lite model. - lite_runner = model_util.get_lite_runner(tflite_file) + lite_runner = model_util.get_lite_runner(tflite_model) lite_output = lite_runner.run(input_tensors) # Gets output from keras model. @@ -95,12 +95,41 @@ def is_same_output(tflite_file: str, def test_tflite(keras_model: tf.keras.Model, - tflite_file: str, + tflite_model: bytearray, size: Union[int, List[int]], high: float = 1, atol: float = 1e-04) -> bool: """Verifies if the output of TFLite model and TF Keras model are identical. + Args: + keras_model: Input TensorFlow Keras model. + tflite_model: Input TFLite model flatbuffer. + size: Size of the input tesnor. + high: Higher boundary of the values in input tensors. + atol: Absolute tolerance of the difference between the outputs of Keras + model and TFLite model. + + Returns: + True if the output of TFLite model and TF Keras model are identical. + Otherwise, False. + """ + random_input = create_random_sample(size=size, high=high) + random_input = tf.convert_to_tensor(random_input) + + return is_same_output( + tflite_model=tflite_model, + keras_model=keras_model, + input_tensors=random_input, + atol=atol) + + +def test_tflite_file(keras_model: tf.keras.Model, + tflite_file: bytearray, + size: Union[int, List[int]], + high: float = 1, + atol: float = 1e-04) -> bool: + """Verifies if the output of TFLite model and TF Keras model are identical. + Args: keras_model: Input TensorFlow Keras model. tflite_file: Input TFLite model file. @@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model, True if the output of TFLite model and TF Keras model are identical. Otherwise, False. 
""" - random_input = create_random_sample(size=size, high=high) - random_input = tf.convert_to_tensor(random_input) - - return is_same_output( - tflite_file=tflite_file, - keras_model=keras_model, - input_tensors=random_input, - atol=atol) + with tf.io.gfile.GFile(tflite_file, "rb") as f: + tflite_model = f.read() + return test_tflite(keras_model, tflite_model, size, high, atol) diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index a2268059f..551c3777c 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -81,6 +81,8 @@ py_library( "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/vision/core:image_preprocessing", + "//mediapipe/tasks/python/metadata/metadata_writers:image_classifier", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", ], ) @@ -88,7 +90,11 @@ py_library( name = "image_classifier_test_lib", testonly = 1, srcs = ["image_classifier_test.py"], - deps = [":image_classifier_import"], + data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"], + deps = [ + ":image_classifier_import", + "//mediapipe/tasks/python/test:test_utils", + ], ) py_test( diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index a3282ebf9..61e7c7152 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """APIs to train image classifier model.""" +import os from typing import List, Optional @@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib +from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer class ImageClassifier(classifier.Classifier): @@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier): self, model_name: str = 'model.tflite', quantization_config: Optional[quantization.QuantizationConfig] = None): - """Converts the model to the requested formats and exports to a file. + """Converts and saves the model to a TFLite file with metadata included. + + Note that only the TFLite file is needed for deployment. This function also + saves a metadata.json file to the same directory as the TFLite file which + can be used to interpret the metadata content in the TFLite file. Args: - model_name: File name to save tflite model. The full export path is - {export_dir}/{tflite_filename}. + model_name: File name to save TFLite model with metadata. The full export + path is {self._hparams.model_dir}/{model_name}. quantization_config: The configuration for model quantization. 
""" - super().export_tflite( - self._hparams.model_dir, - model_name, - quantization_config, + if not tf.io.gfile.exists(self._hparams.model_dir): + tf.io.gfile.makedirs(self._hparams.model_dir) + tflite_file = os.path.join(self._hparams.model_dir, model_name) + metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json') + + tflite_model = model_util.convert_to_tflite( + model=self._model, + quantization_config=quantization_config, preprocess=self._preprocess) + writer = image_classifier_writer.MetadataWriter.create( + tflite_model, + self._model_spec.mean_rgb, + self._model_spec.stddev_rgb, + labels=metadata_writer.Labels().add(self._label_names)) + tflite_model_with_metadata, metadata_json = writer.populate() + model_util.save_tflite(tflite_model_with_metadata, tflite_file) + with open(metadata_file, 'w') as f: + f.write(metadata_json) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 8ed6de7ad..2f949e648 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import filecmp import os from absl.testing import parameterized @@ -19,6 +20,7 @@ import numpy as np import tensorflow as tf from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.tasks.python.test import test_utils def _fill_image(rgb, image_size): @@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): validation_data=self.test_data) self._test_accuracy(model) - def test_efficientnetlite0_model_with_model_maker_retraining_lib(self): + def test_efficientnetlite0_model_train_and_export(self): hparams = image_classifier.HParams( train_epochs=1, batch_size=1, shuffle=True) model = image_classifier.ImageClassifier.create( @@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): validation_data=self.test_data) self._test_accuracy(model) + # Test export_model + model.export_model() + output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json') + output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite') + expected_metadata_file = test_utils.get_test_data_path('metadata.json') + + self.assertTrue(os.path.exists(output_tflite_file)) + self.assertGreater(os.path.getsize(output_tflite_file), 0) + + self.assertTrue(os.path.exists(output_metadata_file)) + self.assertGreater(os.path.getsize(output_metadata_file), 0) + self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + def _test_accuracy(self, model, threshold=0.0): _, accuracy = model.evaluate(self.test_data) self.assertGreaterEqual(accuracy, threshold) diff --git a/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD b/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD new file mode 100644 index 000000000..37730ea91 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD @@ -0,0 +1,23 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = ["metadata.json"], +) diff --git a/mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json b/mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json new file mode 100644 index 000000000..43d4e6d6c --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json @@ -0,0 +1,68 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0 + ], + "std": [ + 255.0 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 5ad057983..fb608f123 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -24,7 +24,7 @@ py_library( srcs = ["test_utils.py"], srcs_version = "PY3", visibility = [ - "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", + "//mediapipe/model_maker/python:__subpackages__", "//mediapipe/tasks:internal", ], deps = [ From 06cb73fc81d09e81632318bf2ecdb68e4fc1a894 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 3 Nov 2022 15:42:59 -0700 Subject: [PATCH 20/34] Modified internal dependencies. 
PiperOrigin-RevId: 485992876 --- .../tensor/inference_calculator_gl_advanced.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 52359f7f5..ad5df849f 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -25,12 +25,12 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/util/tflite/tflite_gpu_runner.h" -#if defined(MEDIAPIPE_ANDROID) +#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/filesystem.h" #include "mediapipe/util/android/file/base/helpers.h" -#endif // MEDIAPIPE_ANDROID +#endif // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe @@ -231,7 +231,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner( return tflite_gpu_runner_->Build(); } -#if defined(MEDIAPIPE_ANDROID) +#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( const mediapipe::InferenceCalculatorOptions& options, const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& @@ -318,7 +318,7 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { return absl::OkStatus(); } -#endif // MEDIAPIPE_ANDROID +#endif // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract( CalculatorContract* cc) { From b2cc2cc60c34efbc36c1a2a14cc8eab0af4dd424 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 3 Nov 2022 17:21:24 -0700 Subject: [PATCH 21/34] Internal changes PiperOrigin-RevId: 486014305 --- mediapipe/model_maker/python/internal/README.md | 4 ---- mediapipe/model_maker/python/internal/__init__.py | 1 - 2 files changed, 5 deletions(-) delete mode 100644 mediapipe/model_maker/python/internal/README.md delete mode 100644 mediapipe/model_maker/python/internal/__init__.py diff --git a/mediapipe/model_maker/python/internal/README.md b/mediapipe/model_maker/python/internal/README.md deleted file mode 100644 index 100d6a520..000000000 --- a/mediapipe/model_maker/python/internal/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# MediaPipe Model Maker Internal Library - -This directory contains model maker library for internal users and experimental -purposes. diff --git a/mediapipe/model_maker/python/internal/__init__.py b/mediapipe/model_maker/python/internal/__init__.py deleted file mode 100644 index 05f41d8a4..000000000 --- a/mediapipe/model_maker/python/internal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Model maker internal library.""" From 5f5f50d8f72d5ca4ee4de26a1aa42a7d2d3ca506 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 08:30:30 -0700 Subject: [PATCH 22/34] Implement MediaPipe Tasks Python AudioData. 
PiperOrigin-RevId: 486147173 --- .../tasks/python/components/containers/BUILD | 5 + .../components/containers/audio_data.py | 109 ++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 mediapipe/tasks/python/components/containers/audio_data.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 20ee501cc..91e115476 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +py_library( + name = "audio_data", + srcs = ["audio_data.py"], +) + py_library( name = "bounding_box", srcs = ["bounding_box.py"], diff --git a/mediapipe/tasks/python/components/containers/audio_data.py b/mediapipe/tasks/python/components/containers/audio_data.py new file mode 100644 index 000000000..21b606079 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/audio_data.py @@ -0,0 +1,109 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe audio data.""" + +import dataclasses +from typing import Optional + +import numpy as np + + +@dataclasses.dataclass +class AudioFormat: + """Audio format metadata. + + Attributes: + num_channels: the number of channels of the audio data. + sample_rate: the audio sample rate. + """ + num_channels: int = 1 + sample_rate: Optional[float] = None + + +class AudioData(object): + """MediaPipe Tasks' audio container.""" + + def __init__( + self, buffer_length: int, + audio_format: AudioFormat = AudioFormat()) -> None: + """Initializes the `AudioData` object. + + Args: + buffer_length: the length of the audio buffer. + audio_format: the audio format metadata. + """ + self._audio_format = audio_format + self._buffer = np.zeros([buffer_length, self._audio_format.num_channels], + dtype=np.float32) + + def clear(self): + """Clears the internal buffer and fill it with zeros.""" + self._buffer.fill(0) + + def load_from_array(self, + src: np.ndarray, + offset: int = 0, + size: int = -1) -> None: + """Loads the audio data from a NumPy array. + + Args: + src: A NumPy source array contains the input audio. + offset: An optional offset for loading a slice of the `src` array to the + buffer. + size: An optional size parameter denoting the number of samples to load + from the `src` array. + + Raises: + ValueError: If the input array has an incorrect shape or if + `offset` + `size` exceeds the length of the `src` array. + """ + if src.shape[1] != self._audio_format.num_channels: + raise ValueError(f"Input audio contains an invalid number of channels. " + f"Expect {self._audio_format.num_channels}.") + + if size < 0: + size = len(src) + + if offset + size > len(src): + raise ValueError( + f"Index out of range. 
offset {offset} + size {size} should be <= " + f"src's length: {len(src)}") + + if len(src) >= len(self._buffer): + # If the internal buffer is shorter than the load target (src), copy + # values from the end of the src array to the internal buffer. + new_offset = offset + size - len(self._buffer) + new_size = len(self._buffer) + self._buffer = src[new_offset:new_offset + new_size].copy() + else: + # Shift the internal buffer backward and add the incoming data to the end + # of the buffer. + shift = size + self._buffer = np.roll(self._buffer, -shift, axis=0) + self._buffer[-shift:, :] = src[offset:offset + size].copy() + + @property + def audio_format(self) -> AudioFormat: + """Gets the audio format of the audio.""" + return self._audio_format + + @property + def buffer_length(self) -> int: + """Gets the sample count of the audio.""" + return self._buffer.shape[0] + + @property + def buffer(self) -> np.ndarray: + """Gets the internal buffer.""" + return self._buffer From 62cd67e99609ed919285d7a4b9a3dbde992821f8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 4 Nov 2022 08:36:31 -0700 Subject: [PATCH 23/34] Internal change PiperOrigin-RevId: 486148328 --- mediapipe/tasks/cc/components/BUILD | 1 + mediapipe/tasks/cc/components/image_preprocessing.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 344fafb4e..4b5439035 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -40,6 +40,7 @@ cc_library( "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc index 7940080e1..ef447df97 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/image_preprocessing.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" From 93a587a42255c4526e1e0550455dece869b387b6 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 08:44:10 -0700 Subject: [PATCH 24/34] Make the python documentation clear to resolve https://github.com/google/mediapipe/issues/3805. PiperOrigin-RevId: 486149904 --- docs/getting_started/python.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md index 289988e55..880d5c85d 100644 --- a/docs/getting_started/python.md +++ b/docs/getting_started/python.md @@ -141,3 +141,4 @@ Nvidia Jetson and Raspberry Pi, please read ```bash (mp_env)mediapipe$ python3 setup.py bdist_wheel ``` +7. Exit from the MediaPipe repo directory and launch the Python interpreter. 
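For reference, a minimal usage sketch of the AudioData container introduced in patch 22 above. This is an illustration only: the import path simply mirrors the new file's location, and the buffer size and sample values are assumptions, not part of the patch.

    import numpy as np

    # Assumed import path, mirroring
    # mediapipe/tasks/python/components/containers/audio_data.py added above.
    from mediapipe.tasks.python.components.containers import audio_data

    # Fixed-length mono buffer holding one second of 16 kHz audio.
    data = audio_data.AudioData(
        buffer_length=16000,
        audio_format=audio_data.AudioFormat(num_channels=1, sample_rate=16000))

    # Feed a 0.5 s chunk: because the chunk is shorter than the buffer, the
    # buffer is rolled backward and the new samples are written at its end,
    # so it always holds the most recent `buffer_length` samples.
    chunk = np.random.uniform(-1.0, 1.0, size=(8000, 1)).astype(np.float32)
    data.load_from_array(chunk)

    print(data.audio_format.sample_rate, data.buffer_length, data.buffer.shape)
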
From 8b2c937b9e573838ebd1ecf5ddd95f6ee9b882f8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 4 Nov 2022 09:44:07 -0700 Subject: [PATCH 25/34] Migrate AudioClassifier C++ to use new ClassificationResult struct. PiperOrigin-RevId: 486162683 --- .../tasks/cc/audio/audio_classifier/BUILD | 1 + .../audio_classifier/audio_classifier.cc | 45 +++- .../audio/audio_classifier/audio_classifier.h | 70 +++--- .../audio_classifier_graph.cc | 48 ++-- .../audio_classifier/audio_classifier_test.cc | 236 +++++++++--------- 5 files changed, 222 insertions(+), 178 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index ac238bfda..501a9e6fd 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -65,6 +65,7 @@ cc_library( "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:classification_result", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index 702d802c5..3b01ddb88 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -18,12 +18,14 @@ limitations under the License. #include #include #include +#include #include "absl/status/statusor.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" @@ -38,12 +40,16 @@ namespace audio_classifier { namespace { +using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioTag[] = "AUDIO"; -constexpr char kClassificationResultStreamName[] = "classification_result_out"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kClassificationsName[] = "classifications_out"; +constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; +constexpr char kTimestampedClassificationsName[] = + "timestamped_classifications_out"; constexpr char kSampleRateName[] = "sample_rate_in"; constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSubgraphTypeName[] = @@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig( } subgraph.GetOptions().Swap( options_proto.get()); - subgraph.Out(kClassificationResultTag) - .SetName(kClassificationResultStreamName) >> - graph.Out(kClassificationResultTag); + subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >> + graph.Out(kClassificationsTag); + 
subgraph.Out(kTimestampedClassificationsTag) + .SetName(kTimestampedClassificationsName) >> + graph.Out(kTimestampedClassificationsTag); return graph.GetConfig(); } @@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { return options_proto; } -absl::StatusOr ConvertOutputPackets( +absl::StatusOr> ConvertOutputPackets( absl::StatusOr status_or_packets) { if (!status_or_packets.ok()) { return status_or_packets.status(); } - return status_or_packets.value()[kClassificationResultStreamName] - .Get(); + auto classification_results = + status_or_packets.value()[kTimestampedClassificationsName] + .Get>(); + std::vector results; + results.reserve(classification_results.size()); + for (const auto& classification_result : classification_results) { + results.emplace_back(ConvertToClassificationResult(classification_result)); + } + return results; +} + +absl::StatusOr ConvertAsyncOutputPackets( + absl::StatusOr status_or_packets) { + if (!status_or_packets.ok()) { + return status_or_packets.status(); + } + return ConvertToClassificationResult( + status_or_packets.value()[kClassificationsName] + .Get()); } } // namespace @@ -118,7 +143,7 @@ absl::StatusOr> AudioClassifier::Create( auto result_callback = options->result_callback; packets_callback = [=](absl::StatusOr status_or_packets) { - result_callback(ConvertOutputPackets(status_or_packets)); + result_callback(ConvertAsyncOutputPackets(status_or_packets)); }; } return core::AudioTaskApiFactory::Create> AudioClassifier::Create( std::move(packets_callback)); } -absl::StatusOr AudioClassifier::Classify( +absl::StatusOr> AudioClassifier::Classify( Matrix audio_clip, double audio_sample_rate) { return ConvertOutputPackets(ProcessAudioClip( {{kAudioStreamName, MakePacket(std::move(audio_clip))}, diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index 200cffb8c..4b5d2c04b 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include #include "absl/status/statusor.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h" -#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" @@ -32,6 +33,10 @@ namespace tasks { namespace audio { namespace audio_classifier { +// Alias the shared ClassificationResult struct as result type. +using AudioClassifierResult = + ::mediapipe::tasks::components::containers::ClassificationResult; + // The options for configuring a mediapipe audio classifier task. struct AudioClassifierOptions { // Base options for configuring Task library, such as specifying the TfLite @@ -59,9 +64,8 @@ struct AudioClassifierOptions { // The user-defined result callback for processing audio stream data. // The result callback should only be specified when the running mode is set // to RunningMode::AUDIO_STREAM. - std::function)> - result_callback = nullptr; + std::function)> result_callback = + nullptr; }; // Performs audio classification on audio clips or audio stream. 
@@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // required to provide the corresponding audio sample rate along with the // input audio clips. // - // For each audio clip, the output classifications are grouped in a - // ClassificationResult object that has three dimensions: - // Classification head: - // The prediction heads targeting different audio classification tasks - // such as audio event classification and bird sound classification. - // Classification timestamp: - // The start time (in milliseconds) of each audio clip that is sent to the - // model for audio classification. As the audio classification models take - // a fixed number of audio samples, long audio clips will be framed to - // multiple buffers (with the desired number of audio samples) during - // preprocessing. - // Classification category: - // The list of the classification categories that model predicts per - // framed audio clip. + // The input audio clip may be longer than what the model is able to process + // in a single inference. When this occurs, the input audio clip is split into + // multiple chunks starting at different timestamps. For this reason, this + // function returns a vector of ClassificationResult objects, each associated + // with a timestamp corresponding to the start (in milliseconds) of the chunk + // data that was classified, e.g: + // + // ClassificationResult #0 (first chunk of data): + // timestamp_ms: 0 (starts at 0ms) + // classifications #0 (single head model): + // category #0: + // category_name: "Speech" + // score: 0.6 + // category #1: + // category_name: "Music" + // score: 0.2 + // ClassificationResult #1 (second chunk of data): + // timestamp_ms: 800 (starts at 800ms) + // classifications #0 (single head model): + // category #0: + // category_name: "Speech" + // score: 0.5 + // category #1: + // category_name: "Silence" + // score: 0.1 + // ... + // // TODO: Use `sample_rate` in AudioClassifierOptions by default // and makes `audio_sample_rate` optional. - absl::StatusOr Classify( + absl::StatusOr> Classify( mediapipe::Matrix audio_clip, double audio_sample_rate); // Sends audio data (a block in a continuous audio stream) to perform audio @@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // milliseconds) to indicate the start time of the input audio block. The // timestamps must be monotonically increasing. // - // The output classifications are grouped in a ClassificationResult object - // that has three dimensions: - // Classification head: - // The prediction heads targeting different audio classification tasks - // such as audio event classification and bird sound classification. - // Classification timestamp : - // The start time (in milliseconds) of the framed audio block that is sent - // to the model for audio classification. - // Classification category: - // The list of the classification categories that model predicts per - // framed audio clip. + // The input audio block may be longer than what the model is able to process + // in a single inference. When this occurs, the input audio block is split + // into multiple chunks. For this reason, the callback may be called multiple + // times (once per chunk) for each call to this function. absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms); // Shuts down the AudioClassifier when all works are done. 
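For reference, a minimal sketch of how the reworked clip-mode API documented in audio_classifier.h above might be invoked. This is an illustration only: the model path is a placeholder, error handling is simplified, and construction of the options struct follows the tests later in this patch; only calls shown in this patch are used.

    #include <iostream>
    #include <memory>
    #include <utility>
    #include <vector>

    #include "absl/status/status.h"
    #include "absl/status/statusor.h"
    #include "mediapipe/framework/formats/matrix.h"
    #include "mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h"

    namespace audio_classifier = ::mediapipe::tasks::audio::audio_classifier;

    absl::Status ClassifyClip(mediapipe::Matrix audio_clip, double sample_rate) {
      auto options = std::make_unique<audio_classifier::AudioClassifierOptions>();
      options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder

      auto classifier_or =
          audio_classifier::AudioClassifier::Create(std::move(options));
      if (!classifier_or.ok()) return classifier_or.status();
      std::unique_ptr<audio_classifier::AudioClassifier> classifier =
          std::move(classifier_or.value());

      // Clip mode returns one AudioClassifierResult per model-sized chunk of
      // the input, each tagged with the chunk's start timestamp in milliseconds.
      auto results_or = classifier->Classify(std::move(audio_clip), sample_rate);
      if (!results_or.ok()) return results_or.status();
      for (const audio_classifier::AudioClassifierResult& result :
           results_or.value()) {
        // The tests in this patch treat categories[0] as the top-scoring
        // category of the first classification head.
        if (!result.classifications.empty() &&
            !result.classifications[0].categories.empty()) {
          std::cout << "top score: "
                    << result.classifications[0].categories[0].score << std::endl;
        }
      }
      return classifier->Close();
    }
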
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 12f8ce31a..2b75209bb 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAudioTag[] = "AUDIO"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; constexpr char kPacketTag[] = "PACKET"; constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; +// Struct holding the different output streams produced by the audio classifier +// graph. +struct AudioClassifierOutputStreams { + Source classifications; + Source> timestamped_classifications; +}; + absl::Status SanityCheckOptions( const proto::AudioClassifierGraphOptions& options) { if (options.base_options().use_stream_mode() && @@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator( // series stream header with sample rate info. // // Outputs: -// CLASSIFICATION_RESULT - ClassificationResult -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). +// CLASSIFICATIONS - ClassificationResult @Optional +// The classification results aggregated by head. Only produces results if +// the graph if the 'use_stream_mode' option is true. +// TIMESTAMPED_CLASSIFICATIONS - std::vector @Optional +// The classification result aggregated by timestamp, then by head. Only +// produces results if the graph if the 'use_stream_mode' option is false. // // Example: // node { // calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph" // input_stream: "AUDIO:audio_in" // input_stream: "SAMPLE_RATE:sample_rate_in" -// output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// output_stream: "CLASSIFICATIONS:classifications" +// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications" // options { // [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext] // { @@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph { .base_options() .use_stream_mode(); ASSIGN_OR_RETURN( - auto classification_result_out, + auto output_streams, BuildAudioClassificationTask( sc->Options(), *model_resources, graph[Input(kAudioTag)], @@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph { ? absl::nullopt : absl::make_optional(graph[Input(kSampleRateTag)]), graph)); - classification_result_out >> - graph[Output(kClassificationResultTag)]; + output_streams.classifications >> + graph[Output(kClassificationsTag)]; + output_streams.timestamped_classifications >> + graph[Output>( + kTimestampedClassificationsTag)]; return graph.GetConfig(); } @@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // audio_in: (mediapipe::Matrix) stream to run audio classification on. // sample_rate_in: (double) optional stream of the input audio sample rate. 
// graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr> BuildAudioClassificationTask( + absl::StatusOr BuildAudioClassificationTask( const proto::AudioClassifierGraphOptions& task_options, const core::ModelResources& model_resources, Source audio_in, absl::optional> sample_rate_in, Graph& graph) { @@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph { inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio classification on - // audio files. Disables time aggregration by not connecting the + // audio files. Disables timestamp aggregation by not connecting the // "TIMESTAMPS" streams. if (!use_stream_mode) { audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag); } - // Outputs the aggregated classification result as the subgraph output - // stream. - return postprocessing[Output( - kClassificationResultTag)]; + // Output both streams as graph output streams/ + return AudioClassifierOutputStreams{ + /*classifications=*/postprocessing[Output( + kClassificationsTag)], + /*timestamped_classifications=*/ + postprocessing[Output>( + kTimestampedClassificationsTag)], + }; } }; diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 591d5e4f7..a4fe5e32e 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -32,13 +32,11 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/parse_text_proto.h" -#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" namespace mediapipe { @@ -49,7 +47,6 @@ namespace { using ::absl::StatusOr; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; @@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) { return matrix_mapping.matrix(); } -void CheckSpeechClassificationResult(const ClassificationResult& result) { - EXPECT_THAT(result.classifications_size(), testing::Eq(1)); - EXPECT_EQ(result.classifications(0).head_name(), "scores"); - EXPECT_EQ(result.classifications(0).head_index(), 0); - EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5)); +void CheckSpeechResult(const std::vector& result, + int expected_num_categories = 521) { + EXPECT_EQ(result.size(), 5); + // Ignore last result, which operates on a too small chunk to return relevant + // results. 
std::vector timestamps_ms = {0, 975, 1950, 2925}; for (int i = 0; i < timestamps_ms.size(); i++) { - EXPECT_THAT(result.classifications(0).entries(0).categories_size(), - testing::Eq(521)); - const auto* top_category = - &result.classifications(0).entries(0).categories(0); - EXPECT_THAT(top_category->category_name(), testing::Eq("Speech")); - EXPECT_GT(top_category->score(), 0.9f); - EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(), - timestamps_ms[i]); + EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]); + EXPECT_EQ(result[i].classifications.size(), 1); + auto classifications = result[i].classifications[0]; + EXPECT_EQ(classifications.head_index, 0); + EXPECT_EQ(classifications.head_name, "scores"); + EXPECT_EQ(classifications.categories.size(), expected_num_categories); + auto category = classifications.categories[0]; + EXPECT_EQ(category.index, 0); + EXPECT_EQ(category.category_name, "Speech"); + EXPECT_GT(category.score, 0.9f); } } -void CheckTwoHeadsClassificationResult(const ClassificationResult& result) { - EXPECT_THAT(result.classifications_size(), testing::Eq(2)); - // Checks classification head #1. - EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification"); - EXPECT_EQ(result.classifications(0).head_index(), 0); - EXPECT_THAT(result.classifications(0).entries(0).categories_size(), - testing::Eq(521)); - const auto* top_category = - &result.classifications(0).entries(0).categories(0); - EXPECT_THAT(top_category->category_name(), - testing::Eq("Environmental noise")); - EXPECT_GT(top_category->score(), 0.5f); - EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0); - if (result.classifications(0).entries_size() == 2) { - top_category = &result.classifications(0).entries(1).categories(0); - EXPECT_THAT(top_category->category_name(), testing::Eq("Silence")); - EXPECT_GT(top_category->score(), 0.99f); - EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975); +void CheckTwoHeadsResult(const std::vector& result) { + EXPECT_GE(result.size(), 1); + EXPECT_LE(result.size(), 2); + // Check first result. + EXPECT_EQ(result[0].timestamp_ms, 0); + EXPECT_EQ(result[0].classifications.size(), 2); + // Check first head. + EXPECT_EQ(result[0].classifications[0].head_index, 0); + EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification"); + EXPECT_EQ(result[0].classifications[0].categories.size(), 521); + EXPECT_EQ(result[0].classifications[0].categories[0].index, 508); + EXPECT_EQ(result[0].classifications[0].categories[0].category_name, + "Environmental noise"); + EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f); + // Check second head. + EXPECT_EQ(result[0].classifications[1].head_index, 1); + EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification"); + EXPECT_EQ(result[0].classifications[1].categories.size(), 5); + EXPECT_EQ(result[0].classifications[1].categories[0].index, 4); + EXPECT_EQ(result[0].classifications[1].categories[0].category_name, + "Chestnut-crowned Antpitta"); + EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f); + // Check second result, if present. + if (result.size() == 2) { + EXPECT_EQ(result[1].timestamp_ms, 975); + EXPECT_EQ(result[1].classifications.size(), 2); + // Check first head. 
+ EXPECT_EQ(result[1].classifications[0].head_index, 0); + EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification"); + EXPECT_EQ(result[1].classifications[0].categories.size(), 521); + EXPECT_EQ(result[1].classifications[0].categories[0].index, 494); + EXPECT_EQ(result[1].classifications[0].categories[0].category_name, + "Silence"); + EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f); + // Check second head. + EXPECT_EQ(result[1].classifications[1].head_index, 1); + EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification"); + EXPECT_EQ(result[1].classifications[1].categories.size(), 5); + EXPECT_EQ(result[1].classifications[1].categories[0].index, 1); + EXPECT_EQ(result[1].classifications[1].categories[0].category_name, + "White-breasted Wood-Wren"); + EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f); } - // Checks classification head #2. - EXPECT_EQ(result.classifications(1).head_name(), "bird_classification"); - EXPECT_EQ(result.classifications(1).head_index(), 1); - EXPECT_THAT(result.classifications(1).entries(0).categories_size(), - testing::Eq(5)); - top_category = &result.classifications(1).entries(0).categories(0); - EXPECT_THAT(top_category->category_name(), - testing::Eq("Chestnut-crowned Antpitta")); - EXPECT_GT(top_category->score(), 0.9f); - EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0); } -ClassificationResult GenerateSpeechClassificationResult() { - return ParseTextProtoOrDie( - R"pb(classifications { - head_index: 0 - head_name: "scores" - entries { - categories { index: 0 score: 0.94140625 category_name: "Speech" } - timestamp_ms: 0 - } - entries { - categories { index: 0 score: 0.9921875 category_name: "Speech" } - timestamp_ms: 975 - } - entries { - categories { index: 0 score: 0.98828125 category_name: "Speech" } - timestamp_ms: 1950 - } - entries { - categories { index: 0 score: 0.99609375 category_name: "Speech" } - timestamp_ms: 2925 - } - entries { - # categories are filtered out due to the low scores. - timestamp_ms: 3900 - } - })pb"); -} - -void CheckStreamingModeClassificationResult( - std::vector outputs) { - ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6); - auto expected_results = GenerateSpeechClassificationResult(); - for (int i = 0; i < outputs.size() - 1; ++i) { - EXPECT_THAT(outputs[i].classifications(0).entries(0), - EqualsProto(expected_results.classifications(0).entries(i))); +void CheckStreamingModeResults(std::vector outputs) { + EXPECT_EQ(outputs.size(), 5); + // Ignore last result, which operates on a too small chunk to return relevant + // results. 
+ for (int i = 0; i < outputs.size() - 1; i++) { + EXPECT_FALSE(outputs[i].timestamp_ms.has_value()); + EXPECT_EQ(outputs[i].classifications.size(), 1); + EXPECT_EQ(outputs[i].classifications[0].head_index, 0); + EXPECT_EQ(outputs[i].classifications[0].head_name, "scores"); + EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1); + EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0); + EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name, + "Speech"); + EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f); } - int last_elem_index = outputs.size() - 1; - EXPECT_EQ( - mediapipe::Timestamp::Done().Value() / 1000, - outputs[last_elem_index].classifications(0).entries(0).timestamp_ms()); } class CreateFromOptionsTest : public tflite_shims::testing::Test {}; @@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); options->result_callback = - [](absl::StatusOr status_or_result) {}; + [](absl::StatusOr status_or_result) {}; StatusOr> audio_classifier_or = AudioClassifier::Create(std::move(options)); @@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) { JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); options->running_mode = core::RunningMode::AUDIO_STREAM; options->result_callback = - [](absl::StatusOr status_or_result) {}; + [](absl::StatusOr status_or_result) {}; StatusOr> audio_classifier_or = AudioClassifier::Create(std::move(options)); @@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/16000)); MP_ASSERT_OK(audio_classifier->Close()); - CheckSpeechClassificationResult(result); + CheckSpeechResult(result); } TEST_F(ClassifyTest, SucceedsWithResampling) { @@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/48000)); MP_ASSERT_OK(audio_classifier->Close()); - CheckSpeechClassificationResult(result); + CheckSpeechResult(result); } TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { @@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { auto result_16k_hz, audio_classifier->Classify(std::move(audio_buffer_16k_hz), /*audio_sample_rate=*/16000)); - CheckSpeechClassificationResult(result_16k_hz); + CheckSpeechResult(result_16k_hz); MP_ASSERT_OK_AND_ASSIGN( auto result_48k_hz, audio_classifier->Classify(std::move(audio_buffer_48k_hz), /*audio_sample_rate=*/48000)); MP_ASSERT_OK(audio_classifier->Close()); - CheckSpeechClassificationResult(result_48k_hz); + CheckSpeechResult(result_48k_hz); } TEST_F(ClassifyTest, SucceedsWithInsufficientData) { @@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) { MP_ASSERT_OK_AND_ASSIGN( auto result, audio_classifier->Classify(std::move(zero_matrix), 16000)); MP_ASSERT_OK(audio_classifier->Close()); - EXPECT_THAT(result.classifications_size(), testing::Eq(1)); - EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1)); - EXPECT_THAT(result.classifications(0).entries(0).categories_size(), - testing::Eq(521)); - EXPECT_THAT( - result.classifications(0).entries(0).categories(0).category_name(), - testing::Eq("Silence")); - EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(), - testing::FloatEq(.800781f)); + EXPECT_EQ(result.size(), 1); + 
EXPECT_EQ(result[0].timestamp_ms, 0); + EXPECT_EQ(result[0].classifications.size(), 1); + EXPECT_EQ(result[0].classifications[0].head_index, 0); + EXPECT_EQ(result[0].classifications[0].head_name, "scores"); + EXPECT_EQ(result[0].classifications[0].categories.size(), 521); + EXPECT_EQ(result[0].classifications[0].categories[0].index, 494); + EXPECT_EQ(result[0].classifications[0].categories[0].category_name, + "Silence"); + EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f); } TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { @@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/16000)); MP_ASSERT_OK(audio_classifier->Close()); - CheckTwoHeadsClassificationResult(result); + CheckTwoHeadsResult(result); } TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { @@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/44100)); MP_ASSERT_OK(audio_classifier->Close()); - CheckTwoHeadsClassificationResult(result); + CheckTwoHeadsResult(result); } TEST_F(ClassifyTest, @@ -413,13 +402,13 @@ TEST_F(ClassifyTest, auto result_44k_hz, audio_classifier->Classify(std::move(audio_buffer_44k_hz), /*audio_sample_rate=*/44100)); - CheckTwoHeadsClassificationResult(result_44k_hz); + CheckTwoHeadsResult(result_44k_hz); MP_ASSERT_OK_AND_ASSIGN( auto result_16k_hz, audio_classifier->Classify(std::move(audio_buffer_16k_hz), /*audio_sample_rate=*/16000)); MP_ASSERT_OK(audio_classifier->Close()); - CheckTwoHeadsClassificationResult(result_16k_hz); + CheckTwoHeadsResult(result_16k_hz); } TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { @@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.max_results = 1; - options->classifier_options.score_threshold = 0.35f; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN( auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/48000)); MP_ASSERT_OK(audio_classifier->Close()); - EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); + CheckSpeechResult(result, /*expected_num_categories=*/1); } TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { @@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/48000)); MP_ASSERT_OK(audio_classifier->Close()); - EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); + CheckSpeechResult(result, /*expected_num_categories=*/1); } TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { @@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/48000)); MP_ASSERT_OK(audio_classifier->Close()); - EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); + CheckSpeechResult(result, /*expected_num_categories=*/1); } TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { @@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { auto result, audio_classifier->Classify(std::move(audio_buffer), /*audio_sample_rate=*/48000)); MP_ASSERT_OK(audio_classifier->Close()); - // All 
categroies with the "Speech" label are filtered out. - EXPECT_THAT(result, EqualsProto(R"pb(classifications { - head_index: 0 - head_name: "scores" - entries { timestamp_ms: 0 } - entries { timestamp_ms: 975 } - entries { timestamp_ms: 1950 } - entries { timestamp_ms: 2925 } - entries { timestamp_ms: 3900 } - })pb")); + // All categories with the "Speech" label are filtered out. + std::vector timestamps_ms = {0, 975, 1950, 2925}; + for (int i = 0; i < timestamps_ms.size(); i++) { + EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]); + EXPECT_EQ(result[i].classifications.size(), 1); + auto classifications = result[i].classifications[0]; + EXPECT_EQ(classifications.head_index, 0); + EXPECT_EQ(classifications.head_name, "scores"); + EXPECT_TRUE(classifications.categories.empty()); + } } class ClassifyAsyncTest : public tflite_shims::testing::Test {}; @@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) { options->classifier_options.score_threshold = 0.3f; options->running_mode = core::RunningMode::AUDIO_STREAM; options->sample_rate = kSampleRateHz; - std::vector outputs; + std::vector outputs; options->result_callback = - [&outputs](absl::StatusOr status_or_result) { + [&outputs](absl::StatusOr status_or_result) { MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result); }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, @@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) { start_col += kYamnetNumOfAudioSamples * 3; } MP_ASSERT_OK(audio_classifier->Close()); - CheckStreamingModeClassificationResult(outputs); + CheckStreamingModeResults(outputs); } TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { @@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { options->classifier_options.score_threshold = 0.3f; options->running_mode = core::RunningMode::AUDIO_STREAM; options->sample_rate = kSampleRateHz; - std::vector outputs; + std::vector outputs; options->result_callback = - [&outputs](absl::StatusOr status_or_result) { + [&outputs](absl::StatusOr status_or_result) { MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result); }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, @@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { start_col += num_samples; } MP_ASSERT_OK(audio_classifier->Close()); - CheckStreamingModeClassificationResult(outputs); + CheckStreamingModeResults(outputs); } } // namespace From 57759aedafb19c1c1703968d616dfcc26460ca22 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 09:45:33 -0700 Subject: [PATCH 26/34] Enable using tasks metadata components in MediaPipe PyPI packages. 
PiperOrigin-RevId: 486162986 --- requirements.txt | 1 + setup.py | 51 +++++++++++++++++++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index 00e51ffd6..85a08aea9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ absl-py attrs>=19.1.0 +flatbuffers>=2.0 matplotlib numpy opencv-contrib-python diff --git a/setup.py b/setup.py index cda53b1c6..40c4a7361 100644 --- a/setup.py +++ b/setup.py @@ -129,6 +129,15 @@ def _add_mp_init_files(): mp_dir_init_file.close() +def _copy_to_build_lib_dir(build_lib, file): + """Copy a file from bazel-bin to the build lib dir.""" + dst = os.path.join(build_lib + '/', file) + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + shutil.copyfile(os.path.join('bazel-bin/', file), dst) + + class GeneratePyProtos(build_ext.build_ext): """Generate MediaPipe Python protobuf files by Protocol Compiler.""" @@ -259,7 +268,7 @@ class BuildModules(build_ext.build_ext): ] if subprocess.call(fetch_model_command) != 0: sys.exit(-1) - self._copy_to_build_lib_dir(external_file) + _copy_to_build_lib_dir(self.build_lib, external_file) def _generate_binary_graph(self, binary_graph_target): """Generate binary graph for a particular MediaPipe binary graph target.""" @@ -277,15 +286,27 @@ class BuildModules(build_ext.build_ext): bazel_command.append('--define=OPENCV=source') if subprocess.call(bazel_command) != 0: sys.exit(-1) - self._copy_to_build_lib_dir(binary_graph_target + '.binarypb') + _copy_to_build_lib_dir(self.build_lib, binary_graph_target + '.binarypb') - def _copy_to_build_lib_dir(self, file): - """Copy a file from bazel-bin to the build lib dir.""" - dst = os.path.join(self.build_lib + '/', file) - dst_dir = os.path.dirname(dst) - if not os.path.exists(dst_dir): - os.makedirs(dst_dir) - shutil.copyfile(os.path.join('bazel-bin/', file), dst) + +class GenerateMetadataSchema(build_ext.build_ext): + """Generate metadata python schema files.""" + + def run(self): + for target in ['metadata_schema_py', 'schema_py']: + bazel_command = [ + 'bazel', + 'build', + '--compilation_mode=opt', + '--define=MEDIAPIPE_DISABLE_GPU=1', + '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), + '//mediapipe/tasks/metadata:' + target, + ] + if subprocess.call(bazel_command) != 0: + sys.exit(-1) + _copy_to_build_lib_dir( + self.build_lib, + 'mediapipe/tasks/metadata/' + target + '_generated.py') class BazelExtension(setuptools.Extension): @@ -375,6 +396,7 @@ class BuildPy(build_py.build_py): build_ext_obj = self.distribution.get_command_obj('build_ext') build_ext_obj.link_opencv = self.link_opencv self.run_command('gen_protos') + self.run_command('generate_metadata_schema') self.run_command('build_modules') self.run_command('build_ext') build_py.build_py.run(self) @@ -434,18 +456,25 @@ setuptools.setup( author_email='mediapipe@google.com', long_description=_get_long_description(), long_description_content_type='text/markdown', - packages=setuptools.find_packages(exclude=['mediapipe.examples.desktop.*']), + packages=setuptools.find_packages( + exclude=['mediapipe.examples.desktop.*', 'mediapipe.model_maker.*']), install_requires=_parse_requirements('requirements.txt'), cmdclass={ 'build_py': BuildPy, - 'gen_protos': GeneratePyProtos, 'build_modules': BuildModules, 'build_ext': BuildExtension, + 'generate_metadata_schema': GenerateMetadataSchema, + 'gen_protos': GeneratePyProtos, 'install': Install, 'restore': Restore, }, ext_modules=[ 
BazelExtension('//mediapipe/python:_framework_bindings'), + BazelExtension( + '//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version'), + BazelExtension( + '//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers' + ), ], zip_safe=False, include_package_data=True, From 5fd3701cfd7564b3b6de7120dfc882355675b033 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 10:37:52 -0700 Subject: [PATCH 27/34] Expose ImageSegmenter and GestureRecognizer APIs in PyPI packages. PiperOrigin-RevId: 486176206 --- mediapipe/tasks/python/vision/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py index def113178..276433a49 100644 --- a/mediapipe/tasks/python/vision/__init__.py +++ b/mediapipe/tasks/python/vision/__init__.py @@ -15,17 +15,25 @@ """MediaPipe Tasks Vision API.""" import mediapipe.tasks.python.vision.core +import mediapipe.tasks.python.vision.gesture_recognizer import mediapipe.tasks.python.vision.image_classifier +import mediapipe.tasks.python.vision.image_segmenter import mediapipe.tasks.python.vision.object_detector +GestureRecognizer = gesture_recognizer.GestureRecognizer +GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions ImageClassifier = image_classifier.ImageClassifier ImageClassifierOptions = image_classifier.ImageClassifierOptions +ImageSegmenter = image_segmenter.ImageSegmenter +ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions ObjectDetector = object_detector.ObjectDetector ObjectDetectorOptions = object_detector.ObjectDetectorOptions RunningMode = core.vision_task_running_mode.VisionTaskRunningMode # Remove unnecessary modules to avoid duplication in API docs. del core +del gesture_recognizer del image_classifier +del image_segmenter del object_detector del mediapipe From f05868d8e0c8b60ed5ee6a9c63ba1c94f2f56126 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 4 Nov 2022 10:57:26 -0700 Subject: [PATCH 28/34] Set up the Open Source build for MediaPipe Tasks Web PiperOrigin-RevId: 486181061 --- .gitignore | 1 + BUILD.bazel | 9 +- WORKSPACE | 43 ++ mediapipe/framework/deps/BUILD | 11 +- mediapipe/framework/port/build_config.bzl | 55 +- .../tasks/web/components/containers/BUILD | 21 + .../web/components/containers/category.d.ts | 38 ++ .../containers/classifications.d.ts | 51 ++ .../web/components/containers/landmark.d.ts | 35 + .../tasks/web/components/processors/BUILD | 33 + .../web/components/processors/base_options.ts | 50 ++ .../processors/classifier_options.ts | 62 ++ .../processors/classifier_result.ts | 61 ++ mediapipe/tasks/web/core/BUILD | 21 + mediapipe/tasks/web/core/base_options.d.ts | 31 + .../tasks/web/core/classifier_options.d.ts | 52 ++ .../tasks/web/core/wasm_loader_options.d.ts | 25 + package.json | 15 + tsconfig.json | 47 ++ yarn.lock | 626 ++++++++++++++++++ 20 files changed, 1275 insertions(+), 12 deletions(-) create mode 100644 mediapipe/tasks/web/components/containers/BUILD create mode 100644 mediapipe/tasks/web/components/containers/category.d.ts create mode 100644 mediapipe/tasks/web/components/containers/classifications.d.ts create mode 100644 mediapipe/tasks/web/components/containers/landmark.d.ts create mode 100644 mediapipe/tasks/web/components/processors/BUILD create mode 100644 mediapipe/tasks/web/components/processors/base_options.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_options.ts create mode 100644 
mediapipe/tasks/web/components/processors/classifier_result.ts create mode 100644 mediapipe/tasks/web/core/BUILD create mode 100644 mediapipe/tasks/web/core/base_options.d.ts create mode 100644 mediapipe/tasks/web/core/classifier_options.d.ts create mode 100644 mediapipe/tasks/web/core/wasm_loader_options.d.ts create mode 100644 package.json create mode 100644 tsconfig.json create mode 100644 yarn.lock diff --git a/.gitignore b/.gitignore index b3a881711..525f0878e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,6 @@ bazel-* mediapipe/MediaPipe.xcodeproj mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user mediapipe/provisioning_profile.mobileprovision +node_modules/ .configure.bazelrc .user.bazelrc diff --git a/BUILD.bazel b/BUILD.bazel index 1973f98af..e3443b83e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,4 +14,9 @@ licenses(["notice"]) -exports_files(["LICENSE"]) +exports_files([ + "LICENSE", + "tsconfig.json", + "package.json", + "yarn.lock", +]) diff --git a/WORKSPACE b/WORKSPACE index 5a47cf6b7..d2d7d8ea7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -501,5 +501,48 @@ libedgetpu_dependencies() load("@coral_crosstool//:configure.bzl", "cc_crosstool") cc_crosstool(name = "crosstool") + +# Node dependencies +http_archive( + name = "build_bazel_rules_nodejs", + sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda", + urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"], +) + +load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies") +build_bazel_rules_nodejs_dependencies() + +# fetches nodejs, npm, and yarn +load("@build_bazel_rules_nodejs//:index.bzl", "node_repositories", "yarn_install") +node_repositories() +yarn_install( + name = "npm", + package_json = "//:package.json", + yarn_lock = "//:yarn.lock", +) + +# Protobuf for Node dependencies +http_archive( + name = "rules_proto_grpc", + sha256 = "bbe4db93499f5c9414926e46f9e35016999a4e9f6e3522482d3760dc61011070", + strip_prefix = "rules_proto_grpc-4.2.0", + urls = ["https://github.com/rules-proto-grpc/rules_proto_grpc/archive/4.2.0.tar.gz"], +) + +http_archive( + name = "com_google_protobuf_javascript", + sha256 = "35bca1729532b0a77280bf28ab5937438e3dcccd6b31a282d9ae84c896b6f6e3", + strip_prefix = "protobuf-javascript-3.21.2", + urls = ["https://github.com/protocolbuffers/protobuf-javascript/archive/refs/tags/v3.21.2.tar.gz"], +) + +load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains", "rules_proto_grpc_repos") +rules_proto_grpc_toolchains() +rules_proto_grpc_repos() + +load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") +rules_proto_dependencies() +rules_proto_toolchains() + load("//third_party:external_files.bzl", "external_files") external_files() diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 120ddc711..a39d7476e 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -15,7 +15,7 @@ # Description: # The dependencies of mediapipe. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) @@ -38,7 +38,7 @@ bzl_library( visibility = ["//mediapipe/framework:__subpackages__"], ) -proto_library( +mediapipe_proto_library( name = "proto_descriptor_proto", srcs = ["proto_descriptor.proto"], visibility = [ @@ -47,13 +47,6 @@ proto_library( ], ) -mediapipe_cc_proto_library( - name = "proto_descriptor_cc_proto", - srcs = ["proto_descriptor.proto"], - visibility = ["//mediapipe/framework:__subpackages__"], - deps = [":proto_descriptor_proto"], -) - cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index fb01d4539..80e9bfc4d 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -4,6 +4,9 @@ """.bzl file for mediapipe open source build configs.""" load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") +load("@npm//@bazel/typescript:index.bzl", "ts_project") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_proto_grpc//js:defs.bzl", "js_proto_library") load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library") def provided_args(**kwargs): @@ -71,7 +74,7 @@ def mediapipe_proto_library( def_jspb_proto: define the jspb_proto_library target def_options_lib: define the mediapipe_options_library target """ - _ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps] + _ignore = [def_portable_proto, def_objc_proto, def_java_proto, portable_deps] # buildifier: disable=unused-variable # The proto_library targets for the compiled ".proto" source files. proto_deps = [":" + name] @@ -119,6 +122,24 @@ def mediapipe_proto_library( compatible_with = compatible_with, )) + if def_jspb_proto: + js_deps = replace_deps(deps, "_proto", "_jspb_proto", False) + proto_library( + name = replace_suffix(name, "_proto", "_lib_proto"), + srcs = srcs, + deps = deps, + ) + js_proto_library( + name = replace_suffix(name, "_proto", "_jspb_proto"), + protos = [replace_suffix(name, "_proto", "_lib_proto")], + output_mode = "NO_PREFIX_FLAT", + # Need to specify this to work around bug in js_proto_library() + # https://github.com/bazelbuild/rules_nodejs/issues/3503 + legacy_path = "unused", + deps = js_deps, + visibility = visibility, + ) + if def_options_lib: cc_deps = replace_deps(deps, "_proto", "_cc_proto") mediapipe_options_library(**provided_args( @@ -182,3 +203,35 @@ def mediapipe_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps default_runtime = "@com_google_protobuf//:protobuf", alwayslink = 1, )) + +def mediapipe_ts_library( + name, + srcs, + visibility = None, + deps = [], + testonly = 0, + allow_unoptimized_namespaces = False): + """Generate ts_project for MediaPipe open source version. + + Args: + name: the name of the cc_proto_library. + srcs: the .proto files of the cc_proto_library for Bazel use. + visibility: visibility of this target. + deps: a list of dependency labels for Bazel use; must be cc_proto_library. + testonly: test only or not. 
+ allow_unoptimized_namespaces: ignored, used only internally + """ + _ignore = [allow_unoptimized_namespaces] # buildifier: disable=unused-variable + + ts_project(**provided_args( + name = name, + srcs = srcs, + visibility = visibility, + deps = deps + [ + "@npm//@types/offscreencanvas", + "@npm//@types/google-protobuf", + ], + testonly = testonly, + declaration = True, + tsconfig = "//:tsconfig.json", + )) diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD new file mode 100644 index 000000000..7d13fadcb --- /dev/null +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -0,0 +1,21 @@ +# This package contains options shared by all MediaPipe Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "category", + srcs = ["category.d.ts"], +) + +mediapipe_ts_library( + name = "classifications", + srcs = ["classifications.d.ts"], + deps = [":category"], +) + +mediapipe_ts_library( + name = "landmark", + srcs = ["landmark.d.ts"], +) diff --git a/mediapipe/tasks/web/components/containers/category.d.ts b/mediapipe/tasks/web/components/containers/category.d.ts new file mode 100644 index 000000000..0da8a061f --- /dev/null +++ b/mediapipe/tasks/web/components/containers/category.d.ts @@ -0,0 +1,38 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** A classification category. */ +export interface Category { + /** The probability score of this label category. */ + score: number; + + /** The index of the category in the corresponding label file. */ + index: number; + + /** + * The label of this category object. Defaults to an empty string if there is + * no category. + */ + categoryName: string; + + /** + * The display name of the label, which may be translated for different + * locales. For example, a label, "apple", may be translated into Spanish for + * display purpose, so that the `display_name` is "manzana". Defaults to an + * empty string if there is no display name. + */ + displayName: string; +} diff --git a/mediapipe/tasks/web/components/containers/classifications.d.ts b/mediapipe/tasks/web/components/containers/classifications.d.ts new file mode 100644 index 000000000..67a259bbe --- /dev/null +++ b/mediapipe/tasks/web/components/containers/classifications.d.ts @@ -0,0 +1,51 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {Category} from '../../../../tasks/web/components/containers/category';
+
+/** List of predicted categories with an optional timestamp. */
+export interface ClassificationEntry {
+  /**
+   * The array of predicted categories, usually sorted by descending scores,
+   * e.g., from high to low probability.
+   */
+  categories: Category[];
+
+  /**
+   * The optional timestamp (in milliseconds) associated to the classification
+   * entry. This is useful for time series use cases, e.g., audio
+   * classification.
+   */
+  timestampMs?: number;
+}
+
+/** Classifications for a given classifier head. */
+export interface Classifications {
+  /** A list of classification entries. */
+  entries: ClassificationEntry[];
+
+  /**
+   * The index of the classifier head these categories refer to. This is
+   * useful for multi-head models.
+   */
+  headIndex: number;
+
+  /**
+   * The name of the classifier head, which is the corresponding tensor
+   * metadata name.
+   */
+  headName: string;
+}
diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts
new file mode 100644
index 000000000..9a1badbba
--- /dev/null
+++ b/mediapipe/tasks/web/components/containers/landmark.d.ts
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Landmark represents a point in 3D space with x, y, z coordinates. If
+ * normalized is true, the landmark coordinates are normalized with respect
+ * to the image dimensions, and the coordinate values are in the range
+ * [0, 1]. Otherwise, it represents a point in world coordinates.
+ */
+export class Landmark {
+  /** The x coordinate of the landmark. */
+  x: number;
+
+  /** The y coordinate of the landmark. */
+  y: number;
+
+  /** The z coordinate of the landmark. */
+  z: number;
+
+  /** Whether this landmark is normalized with respect to the image size. */
+  normalized: boolean;
+}
diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD
new file mode 100644
index 000000000..e6b9adf20
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/BUILD
@@ -0,0 +1,33 @@
+# This package contains options shared by all MediaPipe Tasks for Web.
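+#
+# For reference, a Web task that uses these converters would typically pull
+# them in through mediapipe_ts_library, roughly as sketched below; the
+# "my_classifier_task" target name and source file are hypothetical and only
+# illustrate the wiring to the targets defined in this package:
+#
+#   mediapipe_ts_library(
+#       name = "my_classifier_task",
+#       srcs = ["my_classifier_task.ts"],
+#       deps = [
+#           "//mediapipe/tasks/web/components/processors:base_options",
+#           "//mediapipe/tasks/web/components/processors:classifier_options",
+#           "//mediapipe/tasks/web/components/processors:classifier_result",
+#       ],
+#   )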
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+mediapipe_ts_library(
+    name = "classifier_options",
+    srcs = ["classifier_options.ts"],
+    deps = [
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto",
+        "//mediapipe/tasks/web/core:classifier_options",
+    ],
+)
+
+mediapipe_ts_library(
+    name = "classifier_result",
+    srcs = ["classifier_result.ts"],
+    deps = [
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+        "//mediapipe/tasks/web/components/containers:classifications",
+    ],
+)
+
+mediapipe_ts_library(
+    name = "base_options",
+    srcs = ["base_options.ts"],
+    deps = [
+        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
+        "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
+        "//mediapipe/tasks/web/core",
+    ],
+)
diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts
new file mode 100644
index 000000000..2f7d0db37
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/base_options.ts
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
+import {BaseOptions} from '../../../../tasks/web/core/base_options';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+/**
+ * Converts a BaseOptions API object to its Protobuf representation.
+ * @throws If neither a model asset path nor a model asset buffer is provided
+ */
+export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
+    Promise<BaseOptionsProto> {
+  if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
+    throw new Error(
+        'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
+  }
+  if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
+    throw new Error(
+        'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
+  }
+
+  let modelAssetBuffer = baseOptions.modelAssetBuffer;
+  if (!modelAssetBuffer) {
+    const response = await fetch(baseOptions.modelAssetPath!.toString());
+    modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
+  }
+
+  const proto = new BaseOptionsProto();
+  const externalFile = new ExternalFile();
+  externalFile.setFileContent(modelAssetBuffer);
+  proto.setModelAsset(externalFile);
+  return proto;
+}
diff --git a/mediapipe/tasks/web/components/processors/classifier_options.ts b/mediapipe/tasks/web/components/processors/classifier_options.ts
new file mode 100644
index 000000000..8e01dd373
--- /dev/null
+++ b/mediapipe/tasks/web/components/processors/classifier_options.ts
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +/** + * Converts a ClassifierOptions object to its Proto representation, optionally + * based on existing definition. + * @param options The options object to convert to a Proto. Only options that + * are expliclty provided are set. + * @param baseOptions A base object that options can be merged into. + */ +export function convertClassifierOptionsToProto( + options: ClassifierOptions, + baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto { + const classifierOptions = + baseOptions ? baseOptions.clone() : new ClassifierOptionsProto(); + if (options.displayNamesLocale) { + classifierOptions.setDisplayNamesLocale(options.displayNamesLocale); + } else if (options.displayNamesLocale === undefined) { + classifierOptions.clearDisplayNamesLocale(); + } + + if (options.maxResults) { + classifierOptions.setMaxResults(options.maxResults); + } else if ('maxResults' in options) { // Check for undefined + classifierOptions.clearMaxResults(); + } + + if (options.scoreThreshold) { + classifierOptions.setScoreThreshold(options.scoreThreshold); + } else if ('scoreThreshold' in options) { // Check for undefined + classifierOptions.clearScoreThreshold(); + } + + if (options.categoryAllowlist) { + classifierOptions.setCategoryAllowlistList(options.categoryAllowlist); + } else if ('categoryAllowlist' in options) { // Check for undefined + classifierOptions.clearCategoryAllowlistList(); + } + + if (options.categoryDenylist) { + classifierOptions.setCategoryDenylistList(options.categoryDenylist); + } else if ('categoryDenylist' in options) { // Check for undefined + classifierOptions.clearCategoryDenylistList(); + } + return classifierOptions; +} diff --git a/mediapipe/tasks/web/components/processors/classifier_result.ts b/mediapipe/tasks/web/components/processors/classifier_result.ts new file mode 100644 index 000000000..ade967932 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_result.ts @@ -0,0 +1,61 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; + +const DEFAULT_INDEX = -1; +const DEFAULT_SCORE = 0.0; + +/** + * Converts a ClassificationEntry proto to the ClassificationEntry result + * type. + */ +function convertFromClassificationEntryProto(source: ClassificationEntryProto): + ClassificationEntry { + const categories = source.getCategoriesList().map(category => { + return { + index: category.getIndex() ?? DEFAULT_INDEX, + score: category.getScore() ?? DEFAULT_SCORE, + displayName: category.getDisplayName() ?? '', + categoryName: category.getCategoryName() ?? '', + }; + }); + + return { + categories, + timestampMs: source.getTimestampMs(), + }; +} + +/** + * Converts a ClassificationResult proto to a list of classifications. + */ +export function convertFromClassificationResultProto( + classificationResult: ClassificationResult) : Classifications[] { + const result: Classifications[] = []; + for (const classificationsProto of + classificationResult.getClassificationsList()) { + const classifications: Classifications = { + entries: classificationsProto.getEntriesList().map( + entry => convertFromClassificationEntryProto(entry)), + headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX, + headName: classificationsProto.getHeadName() ?? '', + }; + result.push(classifications); + } + return result; +} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD new file mode 100644 index 000000000..a5547ad6e --- /dev/null +++ b/mediapipe/tasks/web/core/BUILD @@ -0,0 +1,21 @@ +# This package contains options shared by all MediaPipe Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "core", + srcs = [ + "base_options.d.ts", + "wasm_loader_options.d.ts", + ], +) + +mediapipe_ts_library( + name = "classifier_options", + srcs = [ + "classifier_options.d.ts", + ], + deps = [":core"], +) diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/base_options.d.ts new file mode 100644 index 000000000..02a288a87 --- /dev/null +++ b/mediapipe/tasks/web/core/base_options.d.ts @@ -0,0 +1,31 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource url + +/** Options to configure MediaPipe Tasks in general. */ +export interface BaseOptions { + /** + * The model path to the model asset file. Only one of `modelAssetPath` or + * `modelAssetBuffer` can be set. + */ + modelAssetPath?: string; + /** + * A buffer containing the model aaset. Only one of `modelAssetPath` or + * `modelAssetBuffer` can be set. 
+ */ + modelAssetBuffer?: Uint8Array; +} diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts new file mode 100644 index 000000000..356643c55 --- /dev/null +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -0,0 +1,52 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../tasks/web/core/base_options'; + +/** Options to configure the Mediapipe Classifier Task. */ +export interface ClassifierOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ + displayNamesLocale?: string|undefined; + + /** The maximum number of top-scored detection results to return. */ + maxResults?: number|undefined; + + /** + * Overrides the value provided in the model metadata. Results below this + * value are rejected. + */ + scoreThreshold?: number|undefined; + + /** + * Allowlist of category names. If non-empty, detection results whose category + * name is not in this set will be filtered out. Duplicate or unknown category + * names are ignored. Mutually exclusive with `categoryDenylist`. + */ + categoryAllowlist?: string[]|undefined; + + /** + * Denylist of category names. If non-empty, detection results whose category + * name is in this set will be filtered out. Duplicate or unknown category + * names are ignored. Mutually exclusive with `categoryAllowlist`. + */ + categoryDenylist?: string[]|undefined; +} diff --git a/mediapipe/tasks/web/core/wasm_loader_options.d.ts b/mediapipe/tasks/web/core/wasm_loader_options.d.ts new file mode 100644 index 000000000..1925aaf06 --- /dev/null +++ b/mediapipe/tasks/web/core/wasm_loader_options.d.ts @@ -0,0 +1,25 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource url + +/** An object containing the locations of all Wasm assets */ +export interface WasmLoaderOptions { + /** The path to the Wasm loader script. */ + wasmLoaderPath: string; + /** The path to the Wasm binary. 
*/ + wasmBinaryPath: string; +} diff --git a/package.json b/package.json new file mode 100644 index 000000000..f8478a159 --- /dev/null +++ b/package.json @@ -0,0 +1,15 @@ +{ + "name": "medipipe-dev", + "version": "0.0.0-alphga", + "description": "MediaPipe GitHub repo", + "devDependencies": { + "@bazel/typescript": "^5.7.1", + "@types/google-protobuf": "^3.15.6", + "@types/offscreencanvas": "^2019.7.0", + "google-protobuf": "^3.21.2", + "protobufjs": "^7.1.2", + "protobufjs-cli": "^1.0.2", + "ts-protoc-gen": "^0.15.0", + "typescript": "^4.8.4" + } +} diff --git a/tsconfig.json b/tsconfig.json new file mode 100644 index 000000000..c17b1902e --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,47 @@ +{ + "compilerOptions": { + "target": "es2017", + "module": "commonjs", + "lib": ["ES2017", "dom"], + "declaration": true, + "moduleResolution": "node", + "esModuleInterop": true, + "noImplicitAny": true, + "inlineSourceMap": true, + "inlineSources": true, + "strict": true, + "types": ["@types/offscreencanvas"], + "rootDirs": [ + ".", + "./bazel-out/host/bin", + "./bazel-out/darwin-dbg/bin", + "./bazel-out/darwin-fastbuild/bin", + "./bazel-out/darwin-opt/bin", + "./bazel-out/darwin_arm64-dbg/bin", + "./bazel-out/darwin_arm64-fastbuild/bin", + "./bazel-out/darwin_arm64-opt/bin", + "./bazel-out/k8-dbg/bin", + "./bazel-out/k8-fastbuild/bin", + "./bazel-out/k8-opt/bin", + "./bazel-out/x64_windows-dbg/bin", + "./bazel-out/x64_windows-fastbuild/bin", + "./bazel-out/x64_windows-opt/bin", + "./bazel-out/darwin-dbg/bin", + "./bazel-out/darwin-fastbuild/bin", + "./bazel-out/darwin-opt/bin", + "./bazel-out/k8-dbg/bin", + "./bazel-out/k8-fastbuild/bin", + "./bazel-out/k8-opt/bin", + "./bazel-out/x64_windows-dbg/bin", + "./bazel-out/x64_windows-fastbuild/bin", + "./bazel-out/x64_windows-opt/bin", + "./bazel-out/k8-fastbuild-ST-4a519fd6d3e4/bin" + ] + }, + "exclude": [ + "./_bazel_bin", + "./_bazel_buildbot", + "./_bazel_out", + "./_bazel_testlogs" + ] +} diff --git a/yarn.lock b/yarn.lock new file mode 100644 index 000000000..e6398fb1f --- /dev/null +++ b/yarn.lock @@ -0,0 +1,626 @@ +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. 
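+#
+# (Illustrative note: this lockfile pins the versions of the devDependencies
+# declared in package.json above. The yarn_install() rule added to the
+# WORKSPACE reads package.json together with this yarn.lock to populate the
+# @npm external repository, and the file is normally regenerated by running
+# `yarn install` rather than edited by hand.)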
+# yarn lockfile v1 + + +"@babel/parser@^7.9.4": + version "7.20.1" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.1.tgz#3e045a92f7b4623cafc2425eddcb8cf2e54f9cc5" + integrity sha512-hp0AYxaZJhxULfM1zyp7Wgr+pSUKBcP3M+PHnSzWGdXOzg/kHWIgiUWARvubhUKGOEw3xqY4x+lyZ9ytBVcELw== + +"@bazel/typescript@^5.7.1": + version "5.7.1" + resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682" + integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g== + dependencies: + "@bazel/worker" "5.7.1" + semver "5.6.0" + source-map-support "0.5.9" + tsutils "3.21.0" + +"@bazel/worker@5.7.1": + version "5.7.1" + resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad" + integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg== + dependencies: + google-protobuf "^3.6.1" + +"@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf" + integrity sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ== + +"@protobufjs/base64@^1.1.2": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@protobufjs/base64/-/base64-1.1.2.tgz#4c85730e59b9a1f1f349047dbf24296034bb2735" + integrity sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg== + +"@protobufjs/codegen@^2.0.4": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@protobufjs/codegen/-/codegen-2.0.4.tgz#7ef37f0d010fb028ad1ad59722e506d9262815cb" + integrity sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg== + +"@protobufjs/eventemitter@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz#355cbc98bafad5978f9ed095f397621f1d066b70" + integrity sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q== + +"@protobufjs/fetch@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/fetch/-/fetch-1.1.0.tgz#ba99fb598614af65700c1619ff06d454b0d84c45" + integrity sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ== + dependencies: + "@protobufjs/aspromise" "^1.1.1" + "@protobufjs/inquire" "^1.1.0" + +"@protobufjs/float@^1.0.2": + version "1.0.2" + resolved "https://registry.yarnpkg.com/@protobufjs/float/-/float-1.0.2.tgz#5e9e1abdcb73fc0a7cb8b291df78c8cbd97b87d1" + integrity sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ== + +"@protobufjs/inquire@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/inquire/-/inquire-1.1.0.tgz#ff200e3e7cf2429e2dcafc1140828e8cc638f089" + integrity sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q== + +"@protobufjs/path@^1.1.2": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@protobufjs/path/-/path-1.1.2.tgz#6cc2b20c5c9ad6ad0dccfd21ca7673d8d7fbf68d" + integrity sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA== + +"@protobufjs/pool@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/pool/-/pool-1.1.0.tgz#09fd15f2d6d3abfa9b65bc366506d6ad7846ff54" + integrity 
sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw== + +"@protobufjs/utf8@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570" + integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== + +"@types/google-protobuf@^3.15.6": + version "3.15.6" + resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" + integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw== + +"@types/linkify-it@*": + version "3.0.2" + resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9" + integrity sha512-HZQYqbiFVWufzCwexrvh694SOim8z2d+xJl5UNamcvQFejLY/2YUtzXHYi3cHdI7PMlS8ejH2slRAOJQ32aNbA== + +"@types/markdown-it@^12.2.3": + version "12.2.3" + resolved "https://registry.yarnpkg.com/@types/markdown-it/-/markdown-it-12.2.3.tgz#0d6f6e5e413f8daaa26522904597be3d6cd93b51" + integrity sha512-GKMHFfv3458yYy+v/N8gjufHO6MSZKCOXpZc5GXIWWy8uldwfmPn98vp81gZ5f9SVw8YYBctgfJ22a2d7AOMeQ== + dependencies: + "@types/linkify-it" "*" + "@types/mdurl" "*" + +"@types/mdurl@*": + version "1.0.2" + resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" + integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== + +"@types/node@>=13.7.0": + version "18.11.9" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" + integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg== + +"@types/offscreencanvas@^2019.7.0": + version "2019.7.0" + resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d" + integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg== + +acorn-jsx@^5.3.2: + version "5.3.2" + resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" + integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ== + +acorn@^8.8.0: + version "8.8.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" + integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== + +ansi-styles@^4.1.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" + integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== + dependencies: + color-convert "^2.0.1" + +argparse@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38" + integrity sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q== + +balanced-match@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee" + integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw== + +bluebird@^3.7.2: + version "3.7.2" + resolved 
"https://registry.yarnpkg.com/bluebird/-/bluebird-3.7.2.tgz#9f229c15be272454ffa973ace0dbee79a1b0c36f" + integrity sha512-XpNj6GDQzdfW+r2Wnn7xiSAd7TM3jzkxGXBGTtWKuSXv1xUV+azxAm8jdWZN06QTQk+2N2XB9jRDkvbmQmcRtg== + +brace-expansion@^1.1.7: + version "1.1.11" + resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd" + integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA== + dependencies: + balanced-match "^1.0.0" + concat-map "0.0.1" + +brace-expansion@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-2.0.1.tgz#1edc459e0f0c548486ecf9fc99f2221364b9a0ae" + integrity sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA== + dependencies: + balanced-match "^1.0.0" + +buffer-from@^1.0.0: + version "1.1.2" + resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.2.tgz#2b146a6fd72e80b4f55d255f35ed59a3a9a41bd5" + integrity sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ== + +catharsis@^0.9.0: + version "0.9.0" + resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" + integrity sha512-prMTQVpcns/tzFgFVkVp6ak6RykZyWb3gu8ckUpd6YkTlacOd3DXGJjIpD4Q6zJirizvaiAjSSHlOsA+6sNh2A== + dependencies: + lodash "^4.17.15" + +chalk@^4.0.0: + version "4.1.2" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01" + integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +concat-map@0.0.1: + version "0.0.1" + resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" + integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== + +deep-is@~0.1.3: + version "0.1.4" + resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" + integrity sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ== + +entities@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" + integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w== + +escape-string-regexp@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" + integrity sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w== + +escodegen@^1.13.0: + version "1.14.3" + resolved 
"https://registry.yarnpkg.com/escodegen/-/escodegen-1.14.3.tgz#4e7b81fba61581dc97582ed78cab7f0e8d63f503" + integrity sha512-qFcX0XJkdg+PB3xjZZG/wKSuT1PnQWx57+TVSjIMmILd2yC/6ByYElPwJnslDsuWuSAp4AwJGumarAAmJch5Kw== + dependencies: + esprima "^4.0.1" + estraverse "^4.2.0" + esutils "^2.0.2" + optionator "^0.8.1" + optionalDependencies: + source-map "~0.6.1" + +eslint-visitor-keys@^3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/eslint-visitor-keys/-/eslint-visitor-keys-3.3.0.tgz#f6480fa6b1f30efe2d1968aa8ac745b862469826" + integrity sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA== + +espree@^9.0.0: + version "9.4.0" + resolved "https://registry.yarnpkg.com/espree/-/espree-9.4.0.tgz#cd4bc3d6e9336c433265fc0aa016fc1aaf182f8a" + integrity sha512-DQmnRpLj7f6TgN/NYb0MTzJXL+vJF9h3pHy4JhCIs3zwcgez8xmGg3sXHcEO97BrmO2OSvCwMdfdlyl+E9KjOw== + dependencies: + acorn "^8.8.0" + acorn-jsx "^5.3.2" + eslint-visitor-keys "^3.3.0" + +esprima@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/esprima/-/esprima-4.0.1.tgz#13b04cdb3e6c5d19df91ab6987a8695619b0aa71" + integrity sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A== + +estraverse@^4.2.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-4.3.0.tgz#398ad3f3c5a24948be7725e83d11a7de28cdbd1d" + integrity sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw== + +estraverse@^5.1.0: + version "5.3.0" + resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.3.0.tgz#2eea5290702f26ab8fe5370370ff86c965d21123" + integrity sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA== + +esutils@^2.0.2: + version "2.0.3" + resolved "https://registry.yarnpkg.com/esutils/-/esutils-2.0.3.tgz#74d2eb4de0b8da1293711910d50775b9b710ef64" + integrity sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g== + +fast-levenshtein@~2.0.6: + version "2.0.6" + resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" + integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== + +fs.realpath@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" + integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw== + +glob@^7.1.3: + version "7.2.3" + resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" + integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^3.1.1" + once "^1.3.0" + path-is-absolute "^1.0.0" + +glob@^8.0.0: + version "8.0.3" + resolved "https://registry.yarnpkg.com/glob/-/glob-8.0.3.tgz#415c6eb2deed9e502c68fa44a272e6da6eeca42e" + integrity sha512-ull455NHSHI/Y1FqGaaYFaLGkNMMJbavMrEGFXG/PGrg6y7sutWHUHrz6gy6WEBH6akM1M414dWKCNs+IhKdiQ== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^5.0.1" + once "^1.3.0" + +google-protobuf@^3.15.5, google-protobuf@^3.21.2, google-protobuf@^3.6.1: + version "3.21.2" + resolved "https://registry.yarnpkg.com/google-protobuf/-/google-protobuf-3.21.2.tgz#4580a2bea8bbb291ee579d1fefb14d6fa3070ea4" + 
integrity sha512-3MSOYFO5U9mPGikIYCzK0SaThypfGgS6bHqrUGXG3DPHCrb+txNqeEcns1W0lkGfk0rCyNXm7xB9rMxnCiZOoA== + +graceful-fs@^4.1.9: + version "4.2.10" + resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.10.tgz#147d3a006da4ca3ce14728c7aefc287c367d7a6c" + integrity sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA== + +has-flag@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" + integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== + +inflight@^1.0.4: + version "1.0.6" + resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" + integrity sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA== + dependencies: + once "^1.3.0" + wrappy "1" + +inherits@2: + version "2.0.4" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" + integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== + +js2xmlparser@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" + integrity sha512-6n4D8gLlLf1n5mNLQPRfViYzu9RATblzPEtm1SthMX1Pjao0r9YI9nw7ZIfRxQMERS87mcswrg+r/OYrPRX6jA== + dependencies: + xmlcreate "^2.0.4" + +jsdoc@^3.6.3: + version "3.6.11" + resolved "https://registry.yarnpkg.com/jsdoc/-/jsdoc-3.6.11.tgz#8bbb5747e6f579f141a5238cbad4e95e004458ce" + integrity sha512-8UCU0TYeIYD9KeLzEcAu2q8N/mx9O3phAGl32nmHlE0LpaJL71mMkP4d+QE5zWfNt50qheHtOZ0qoxVrsX5TUg== + dependencies: + "@babel/parser" "^7.9.4" + "@types/markdown-it" "^12.2.3" + bluebird "^3.7.2" + catharsis "^0.9.0" + escape-string-regexp "^2.0.0" + js2xmlparser "^4.0.2" + klaw "^3.0.0" + markdown-it "^12.3.2" + markdown-it-anchor "^8.4.1" + marked "^4.0.10" + mkdirp "^1.0.4" + requizzle "^0.2.3" + strip-json-comments "^3.1.0" + taffydb "2.6.2" + underscore "~1.13.2" + +klaw@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/klaw/-/klaw-3.0.0.tgz#b11bec9cf2492f06756d6e809ab73a2910259146" + integrity sha512-0Fo5oir+O9jnXu5EefYbVK+mHMBeEVEy2cmctR1O1NECcCkPRreJKrS6Qt/j3KC2C148Dfo9i3pCmCMsdqGr0g== + dependencies: + graceful-fs "^4.1.9" + +levn@~0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/levn/-/levn-0.3.0.tgz#3b09924edf9f083c0490fdd4c0bc4421e04764ee" + integrity sha512-0OO4y2iOHix2W6ujICbKIaEQXvFQHue65vUG3pb5EUomzPI90z9hsA1VsO/dbIIpC53J8gxM9Q4Oho0jrCM/yA== + dependencies: + prelude-ls "~1.1.2" + type-check "~0.3.2" + +linkify-it@^3.0.1: + version "3.0.3" + resolved "https://registry.yarnpkg.com/linkify-it/-/linkify-it-3.0.3.tgz#a98baf44ce45a550efb4d49c769d07524cc2fa2e" + integrity sha512-ynTsyrFSdE5oZ/O9GEf00kPngmOfVwazR5GKDq6EYfhlpFug3J2zybX56a2PRRpc9P+FuSoGNAwjlbDs9jJBPQ== + dependencies: + uc.micro "^1.0.1" + +lodash@^4.17.14, lodash@^4.17.15: + version "4.17.21" + resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" + integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== + +long@^5.0.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/long/-/long-5.2.0.tgz#2696dadf4b4da2ce3f6f6b89186085d94d52fd61" + integrity sha512-9RTUNjK60eJbx3uz+TEGF7fUr29ZDxR5QzXcyDpeSfeH28S9ycINflOgOlppit5U+4kNTe83KQnMEerw7GmE8w== + 
+lru-cache@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-6.0.0.tgz#6d6fe6570ebd96aaf90fcad1dafa3b2566db3a94" + integrity sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA== + dependencies: + yallist "^4.0.0" + +markdown-it-anchor@^8.4.1: + version "8.6.5" + resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" + integrity sha512-PI1qEHHkTNWT+X6Ip9w+paonfIQ+QZP9sCeMYi47oqhH+EsW8CrJ8J7CzV19QVOj6il8ATGbK2nTECj22ZHGvQ== + +markdown-it@^12.3.2: + version "12.3.2" + resolved "https://registry.yarnpkg.com/markdown-it/-/markdown-it-12.3.2.tgz#bf92ac92283fe983fe4de8ff8abfb5ad72cd0c90" + integrity sha512-TchMembfxfNVpHkbtriWltGWc+m3xszaRD0CZup7GFFhzIgQqxIfn3eGj1yZpfuflzPvfkt611B2Q/Bsk1YnGg== + dependencies: + argparse "^2.0.1" + entities "~2.1.0" + linkify-it "^3.0.1" + mdurl "^1.0.1" + uc.micro "^1.0.5" + +marked@^4.0.10: + version "4.2.1" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.1.tgz#eaa32594e45b4e58c02e4d118531fd04345de3b4" + integrity sha512-VK1/jNtwqDLvPktNpL0Fdg3qoeUZhmRsuiIjPEy/lHwXW4ouLoZfO4XoWd4ClDt+hupV1VLpkZhEovjU0W/kqA== + +mdurl@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e" + integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g== + +minimatch@^3.1.1: + version "3.1.2" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" + integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== + dependencies: + brace-expansion "^1.1.7" + +minimatch@^5.0.1: + version "5.1.0" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7" + integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg== + dependencies: + brace-expansion "^2.0.1" + +minimist@^1.2.0: + version "1.2.7" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.7.tgz#daa1c4d91f507390437c6a8bc01078e7000c4d18" + integrity sha512-bzfL1YUZsP41gmu/qjrEk0Q6i2ix/cVeAhbCbqH9u3zYutS1cLg00qhrD0M2MVdCcx4Sc0UpP2eBWo9rotpq6g== + +mkdirp@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-1.0.4.tgz#3eb5ed62622756d79a5f0e2a221dfebad75c2f7e" + integrity sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw== + +once@^1.3.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1" + integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w== + dependencies: + wrappy "1" + +optionator@^0.8.1: + version "0.8.3" + resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.3.tgz#84fa1d036fe9d3c7e21d99884b601167ec8fb495" + integrity sha512-+IW9pACdk3XWmmTXG8m3upGUJst5XRGzxMRjXzAuJ1XnIFNvfhjjIuYkDvysnPQ7qzqVzLt78BCruntqRhWQbA== + dependencies: + deep-is "~0.1.3" + fast-levenshtein "~2.0.6" + levn "~0.3.0" + prelude-ls "~1.1.2" + type-check "~0.3.2" + word-wrap "~1.2.3" + +path-is-absolute@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" + integrity 
sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== + +prelude-ls@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.1.2.tgz#21932a549f5e52ffd9a827f570e04be62a97da54" + integrity sha512-ESF23V4SKG6lVSGZgYNpbsiaAkdab6ZgOxe52p7+Kid3W3u3bxR4Vfd/o21dmN7jSt0IwgZ4v5MUd26FEtXE9w== + +protobufjs-cli@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/protobufjs-cli/-/protobufjs-cli-1.0.2.tgz#905fc49007cf4aaf3c45d5f250eb294eedeea062" + integrity sha512-cz9Pq9p/Zs7okc6avH20W7QuyjTclwJPgqXG11jNaulfS3nbVisID8rC+prfgq0gbZE0w9LBFd1OKFF03kgFzg== + dependencies: + chalk "^4.0.0" + escodegen "^1.13.0" + espree "^9.0.0" + estraverse "^5.1.0" + glob "^8.0.0" + jsdoc "^3.6.3" + minimist "^1.2.0" + semver "^7.1.2" + tmp "^0.2.1" + uglify-js "^3.7.7" + +protobufjs@^7.1.2: + version "7.1.2" + resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-7.1.2.tgz#a0cf6aeaf82f5625bffcf5a38b7cd2a7de05890c" + integrity sha512-4ZPTPkXCdel3+L81yw3dG6+Kq3umdWKh7Dc7GW/CpNk4SX3hK58iPCWeCyhVTDrbkNeKrYNZ7EojM5WDaEWTLQ== + dependencies: + "@protobufjs/aspromise" "^1.1.2" + "@protobufjs/base64" "^1.1.2" + "@protobufjs/codegen" "^2.0.4" + "@protobufjs/eventemitter" "^1.1.0" + "@protobufjs/fetch" "^1.1.0" + "@protobufjs/float" "^1.0.2" + "@protobufjs/inquire" "^1.1.0" + "@protobufjs/path" "^1.1.2" + "@protobufjs/pool" "^1.1.0" + "@protobufjs/utf8" "^1.1.0" + "@types/node" ">=13.7.0" + long "^5.0.0" + +requizzle@^0.2.3: + version "0.2.3" + resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded" + integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ== + dependencies: + lodash "^4.17.14" + +rimraf@^3.0.0: + version "3.0.2" + resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-3.0.2.tgz#f1a5402ba6220ad52cc1282bac1ae3aa49fd061a" + integrity sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA== + dependencies: + glob "^7.1.3" + +semver@5.6.0: + version "5.6.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" + integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg== + +semver@^7.1.2: + version "7.3.8" + resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798" + integrity sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A== + dependencies: + lru-cache "^6.0.0" + +source-map-support@0.5.9: + version "0.5.9" + resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f" + integrity sha512-gR6Rw4MvUlYy83vP0vxoVNzM6t8MUXqNuRsuBmBHQDu1Fh6X015FrLdgoDKcNdkwGubozq0P4N0Q37UyFVr1EA== + dependencies: + buffer-from "^1.0.0" + source-map "^0.6.0" + +source-map@^0.6.0, source-map@~0.6.1: + version "0.6.1" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" + integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== + +strip-json-comments@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" + integrity sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig== + 
+supports-color@^7.1.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.2.0.tgz#1b7dcdcb32b8138801b3e478ba6a51caa89648da" + integrity sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw== + dependencies: + has-flag "^4.0.0" + +taffydb@2.6.2: + version "2.6.2" + resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268" + integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== + +tmp@^0.2.1: + version "0.2.1" + resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" + integrity sha512-76SUhtfqR2Ijn+xllcI5P1oyannHNHByD80W1q447gU3mp9G9PSpGdWmjUOHRDPiHYacIk66W7ubDTuPF3BEtQ== + dependencies: + rimraf "^3.0.0" + +ts-protoc-gen@^0.15.0: + version "0.15.0" + resolved "https://registry.yarnpkg.com/ts-protoc-gen/-/ts-protoc-gen-0.15.0.tgz#2fec5930b46def7dcc9fa73c060d770b7b076b7b" + integrity sha512-TycnzEyrdVDlATJ3bWFTtra3SCiEP0W0vySXReAuEygXCUr1j2uaVyL0DhzjwuUdQoW5oXPwk6oZWeA0955V+g== + dependencies: + google-protobuf "^3.15.5" + +tslib@^1.8.1: + version "1.14.1" + resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00" + integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg== + +tsutils@3.21.0: + version "3.21.0" + resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.21.0.tgz#b48717d394cea6c1e096983eed58e9d61715b623" + integrity sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA== + dependencies: + tslib "^1.8.1" + +type-check@~0.3.2: + version "0.3.2" + resolved "https://registry.yarnpkg.com/type-check/-/type-check-0.3.2.tgz#5884cab512cf1d355e3fb784f30804b2b520db72" + integrity sha512-ZCmOJdvOWDBYJlzAoFkC+Q0+bUyEOS1ltgp1MGU03fqHG+dbi9tBFU2Rd9QKiDZFAYrhPh2JUf7rZRIuHRKtOg== + dependencies: + prelude-ls "~1.1.2" + +typescript@^4.8.4: + version "4.8.4" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6" + integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ== + +uc.micro@^1.0.1, uc.micro@^1.0.5: + version "1.0.6" + resolved "https://registry.yarnpkg.com/uc.micro/-/uc.micro-1.0.6.tgz#9c411a802a409a91fc6cf74081baba34b24499ac" + integrity sha512-8Y75pvTYkLJW2hWQHXxoqRgV7qb9B+9vFEtidML+7koHUFapnVJAZ6cKs+Qjz5Aw3aZWHMC6u0wJE3At+nSGwA== + +uglify-js@^3.7.7: + version "3.17.4" + resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c" + integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g== + +underscore@~1.13.2: + version "1.13.6" + resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441" + integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A== + +word-wrap@~1.2.3: + version "1.2.3" + resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" + integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== + +wrappy@1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" + integrity 
sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ== + +xmlcreate@^2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be" + integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg== + +yallist@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" + integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== From 66e591d4bce676381994f1b3c3af68a2aada31af Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 11:13:30 -0700 Subject: [PATCH 29/34] Remove the use of designated initializers. PiperOrigin-RevId: 486185274 --- .../cc/vision/image_segmenter/image_segmenter_graph.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index d3e522d92..31fed6d8d 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -287,10 +287,9 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { tensor_to_images[Output::Multiple(kSegmentationTag)][i])); } } - return {{ - .segmented_masks = segmented_masks, - .image = preprocessing[Output(kImageTag)], - }}; + return ImageSegmenterOutputs{ + /*segmented_masks=*/segmented_masks, + /*image=*/preprocessing[Output(kImageTag)]}; } }; From 5e1a2fcdbb0077baf96760b416a5e0d9b10c83e3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 4 Nov 2022 11:18:27 -0700 Subject: [PATCH 30/34] Update EmbeddingResult format and dependent tasks. 
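Concretely, the former EmbeddingEntry layer is folded into the Embedding message itself, so each embedding now carries its float_embedding or quantized_embedding directly alongside head_index and head_name; the TensorsToEmbeddingsCalculator output stream tag is renamed from EMBEDDING_RESULT to EMBEDDINGS; and embedder_options moves from mediapipe/tasks/cc/components/proto to mediapipe/tasks/cc/components/processors/proto. The dependent graphs, calculators, and tests below are updated to match.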
PiperOrigin-RevId: 486186491 --- mediapipe/tasks/cc/components/BUILD | 35 ----- .../tasks/cc/components/calculators/BUILD | 4 +- .../tensors_to_embeddings_calculator.cc | 32 ++--- .../tensors_to_embeddings_calculator.proto | 6 +- .../tensors_to_embeddings_calculator_test.cc | 127 +++++++----------- .../tasks/cc/components/containers/BUILD | 9 ++ .../components/containers/embedding_result.cc | 57 ++++++++ .../components/containers/embedding_result.h | 72 ++++++++++ .../containers/proto/embeddings.proto | 31 ++--- .../tasks/cc/components/processors/BUILD | 35 +++++ .../{ => processors}/embedder_options.cc | 10 +- .../{ => processors}/embedder_options.h | 12 +- .../embedding_postprocessing_graph.cc | 20 +-- .../embedding_postprocessing_graph.h | 16 ++- .../embedding_postprocessing_graph_test.cc | 60 +++++---- .../cc/components/processors/proto/BUILD | 15 +++ .../proto/embedder_options.proto | 5 +- ...bedding_postprocessing_graph_options.proto | 2 +- mediapipe/tasks/cc/components/proto/BUILD | 15 --- mediapipe/tasks/cc/components/utils/BUILD | 4 +- .../cc/components/utils/cosine_similarity.cc | 40 +++--- .../cc/components/utils/cosine_similarity.h | 14 +- .../utils/cosine_similarity_test.cc | 48 ++++--- .../tasks/cc/vision/image_embedder/BUILD | 9 +- .../vision/image_embedder/image_embedder.cc | 39 +++--- .../cc/vision/image_embedder/image_embedder.h | 30 +++-- .../image_embedder/image_embedder_graph.cc | 23 ++-- .../image_embedder/image_embedder_test.cc | 109 +++++++-------- .../cc/vision/image_embedder/proto/BUILD | 2 +- .../proto/image_embedder_graph_options.proto | 4 +- .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + 31 files changed, 499 insertions(+), 387 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/embedding_result.cc create mode 100644 mediapipe/tasks/cc/components/containers/embedding_result.h rename mediapipe/tasks/cc/components/{ => processors}/embedder_options.cc (77%) rename mediapipe/tasks/cc/components/{ => processors}/embedder_options.h (79%) rename mediapipe/tasks/cc/components/{ => processors}/embedding_postprocessing_graph.cc (94%) rename mediapipe/tasks/cc/components/{ => processors}/embedding_postprocessing_graph.h (77%) rename mediapipe/tasks/cc/components/{ => processors}/embedding_postprocessing_graph_test.cc (71%) rename mediapipe/tasks/cc/components/{ => processors}/proto/embedder_options.proto (88%) rename mediapipe/tasks/cc/components/{ => processors}/proto/embedding_postprocessing_graph_options.proto (96%) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 4b5439035..c90349ab2 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -61,41 +61,6 @@ cc_library( # TODO: Enable this test -cc_library( - name = "embedder_options", - srcs = ["embedder_options.cc"], - hdrs = ["embedder_options.h"], - deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"], -) - -cc_library( - name = "embedding_postprocessing_graph", - srcs = ["embedding_postprocessing_graph.cc"], - hdrs = ["embedding_postprocessing_graph.h"], - deps = [ - "//mediapipe/calculators/tensor:tensors_dequantization_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/tool:options_map", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", - 
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - # TODO: Investigate rewriting the build rule to only link # the Bert Preprocessor if it's needed. cc_library( diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index a688f291a..061875272 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -163,7 +163,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:embedder_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto", ], ) @@ -178,7 +178,7 @@ cc_library( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", ], diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc index 05b3e1f3f..3ea9bcca4 100644 --- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc @@ -26,14 +26,14 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" namespace mediapipe { namespace api2 { namespace { -using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::mediapipe::tasks::components::containers::proto::Embedding; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; // Computes the inverse L2 norm of the provided array of values. 
Returns 1.0 in @@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) { class TensorsToEmbeddingsCalculator : public Node { public: static constexpr Input> kTensorsIn{"TENSORS"}; - static constexpr Output kEmbeddingsOut{"EMBEDDING_RESULT"}; + static constexpr Output kEmbeddingsOut{"EMBEDDINGS"}; MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut); absl::Status Open(CalculatorContext* cc) override; @@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node { bool quantize_; std::vector head_names_; - void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); - void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); + void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding); + void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding); }; absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { @@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < tensors.size(); ++i) { const auto& tensor = tensors[i]; RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); - auto* embeddings = result.add_embeddings(); - embeddings->set_head_index(i); + auto* embedding = result.add_embeddings(); + embedding->set_head_index(i); if (!head_names_.empty()) { - embeddings->set_head_name(head_names_[i]); + embedding->set_head_name(head_names_[i]); } if (quantize_) { - FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries()); + FillQuantizedEmbedding(tensor, embedding); } else { - FillFloatEmbeddingEntry(tensor, embeddings->add_entries()); + FillFloatEmbedding(tensor, embedding); } } kEmbeddingsOut(cc).Send(result); return absl::OkStatus(); } -void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry( - const Tensor& tensor, EmbeddingEntry* entry) { +void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor, + Embedding* embedding) { int size = tensor.shape().num_elements(); auto tensor_view = tensor.GetCpuReadView(); const float* tensor_buffer = tensor_view.buffer(); float inv_l2_norm = l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; - auto* float_embedding = entry->mutable_float_embedding(); + auto* float_embedding = embedding->mutable_float_embedding(); for (int i = 0; i < size; ++i) { float_embedding->add_values(tensor_buffer[i] * inv_l2_norm); } } -void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry( - const Tensor& tensor, EmbeddingEntry* entry) { +void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding( + const Tensor& tensor, Embedding* embedding) { int size = tensor.shape().num_elements(); auto tensor_view = tensor.GetCpuReadView(); const float* tensor_buffer = tensor_view.buffer(); float inv_l2_norm = l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; - auto* values = entry->mutable_quantized_embedding()->mutable_values(); + auto* values = embedding->mutable_quantized_embedding()->mutable_values(); values->resize(size); for (int i = 0; i < size; ++i) { // Normalize. 
diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto index 2f088c503..745052afa 100644 --- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; message TensorsToEmbeddingsCalculatorOptions { extend mediapipe.CalculatorOptions { @@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions { // The embedder options defining whether to L2-normalize or scalar-quantize // the outputs. - optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options = - 1; + optional mediapipe.tasks.components.processors.proto.EmbedderOptions + embedder_options = 1; // The embedder head names. repeated string head_names = 2; diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc index b6d319121..b79cf4863 100644 --- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc @@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" input_stream: "TENSORS:tensors" - output_stream: "EMBEDDING_RESULT:embeddings" + output_stream: "EMBEDDINGS:embeddings" options { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" } } @@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" input_stream: "TENSORS:tensors" - output_stream: "EMBEDDING_RESULT:embeddings" + output_stream: "EMBEDDINGS:embeddings" options { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { embedder_options { l2_normalize: false quantize: false } @@ -84,28 +84,24 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) { BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); MP_ASSERT_OK(runner.Run()); - const EmbeddingResult& result = runner.Outputs() - .Get("EMBEDDING_RESULT", 0) - .packets[0] - .Get(); - EXPECT_THAT( - result, - EqualsProto(ParseTextProtoOrDie( - R"pb(embeddings { - entries { float_embedding { values: 0.1 values: 0.2 } } - head_index: 0 - } - embeddings { - entries { float_embedding { values: -0.2 values: -0.3 } } - head_index: 1 - })pb"))); + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); + EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + float_embedding { values: 0.1 values: 0.2 } + head_index: 0 + } + embeddings { + float_embedding { values: -0.2 values: -0.3 } + head_index: 1 + })pb"))); } TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" input_stream: "TENSORS:tensors" - output_stream: "EMBEDDING_RESULT:embeddings" + output_stream: "EMBEDDINGS:embeddings" options { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { embedder_options { 
l2_normalize: false quantize: false } @@ -118,30 +114,26 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); MP_ASSERT_OK(runner.Run()); - const EmbeddingResult& result = runner.Outputs() - .Get("EMBEDDING_RESULT", 0) - .packets[0] - .Get(); - EXPECT_THAT( - result, - EqualsProto(ParseTextProtoOrDie( - R"pb(embeddings { - entries { float_embedding { values: 0.1 values: 0.2 } } - head_index: 0 - head_name: "foo" - } - embeddings { - entries { float_embedding { values: -0.2 values: -0.3 } } - head_index: 1 - head_name: "bar" - })pb"))); + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); + EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + float_embedding { values: 0.1 values: 0.2 } + head_index: 0 + head_name: "foo" + } + embeddings { + float_embedding { values: -0.2 values: -0.3 } + head_index: 1 + head_name: "bar" + })pb"))); } TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" input_stream: "TENSORS:tensors" - output_stream: "EMBEDDING_RESULT:embeddings" + output_stream: "EMBEDDINGS:embeddings" options { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { embedder_options { l2_normalize: true quantize: false } @@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); MP_ASSERT_OK(runner.Run()); - const EmbeddingResult& result = runner.Outputs() - .Get("EMBEDDING_RESULT", 0) - .packets[0] - .Get(); + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); EXPECT_THAT( result, EqualsProto(ParseTextProtoOrDie( R"pb(embeddings { - entries { - float_embedding { values: 0.44721356 values: 0.8944271 } - } + float_embedding { values: 0.44721356 values: 0.8944271 } head_index: 0 } embeddings { - entries { - float_embedding { values: -0.5547002 values: -0.8320503 } - } + float_embedding { values: -0.5547002 values: -0.8320503 } head_index: 1 })pb"))); } @@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" input_stream: "TENSORS:tensors" - output_stream: "EMBEDDING_RESULT:embeddings" + output_stream: "EMBEDDINGS:embeddings" options { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { embedder_options { l2_normalize: false quantize: true } @@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) { BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); MP_ASSERT_OK(runner.Run()); - const EmbeddingResult& result = runner.Outputs() - .Get("EMBEDDING_RESULT", 0) - .packets[0] - .Get(); + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( R"pb(embeddings { - entries { - quantized_embedding { values: "\x0d\x1a" } # 13,26 - } + quantized_embedding { values: "\x0d\x1a" } # 13,26 head_index: 0 } embeddings { - entries { - quantized_embedding { values: "\xe6\xda" } # -26,-38 - } + quantized_embedding { values: "\xe6\xda" } # -26,-38 head_index: 1 })pb"))); } @@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" input_stream: "TENSORS:tensors" - output_stream: "EMBEDDING_RESULT:embeddings" + output_stream: 
"EMBEDDINGS:embeddings" options { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { embedder_options { l2_normalize: true quantize: true } @@ -224,25 +204,18 @@ TEST(TensorsToEmbeddingsCalculatorTest, BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); MP_ASSERT_OK(runner.Run()); - const EmbeddingResult& result = runner.Outputs() - .Get("EMBEDDING_RESULT", 0) - .packets[0] - .Get(); - EXPECT_THAT( - result, - EqualsProto(ParseTextProtoOrDie( - R"pb(embeddings { - entries { - quantized_embedding { values: "\x39\x72" } # 57,114 - } - head_index: 0 - } - embeddings { - entries { - quantized_embedding { values: "\xb9\x95" } # -71,-107 - } - head_index: 1 - })pb"))); + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); + EXPECT_THAT(result, + EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + quantized_embedding { values: "\x39\x72" } # 57,114 + head_index: 0 + } + embeddings { + quantized_embedding { values: "\xb9\x95" } # -71,-107 + head_index: 1 + })pb"))); } } // namespace diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 7a52f11e0..5004383d2 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,3 +49,12 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", ], ) + +cc_library( + name = "embedding_result", + srcs = ["embedding_result.cc"], + hdrs = ["embedding_result.h"], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/containers/embedding_result.cc b/mediapipe/tasks/cc/components/containers/embedding_result.cc new file mode 100644 index 000000000..9de55911b --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/embedding_result.cc @@ -0,0 +1,57 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" + +#include +#include +#include +#include + +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe::tasks::components::containers { + +Embedding ConvertToEmbedding(const proto::Embedding& proto) { + Embedding embedding; + if (proto.has_float_embedding()) { + embedding.float_embedding = { + std::make_move_iterator(proto.float_embedding().values().begin()), + std::make_move_iterator(proto.float_embedding().values().end())}; + } else { + embedding.quantized_embedding = { + std::make_move_iterator(proto.quantized_embedding().values().begin()), + std::make_move_iterator(proto.quantized_embedding().values().end())}; + } + embedding.head_index = proto.head_index(); + if (proto.has_head_name()) { + embedding.head_name = proto.head_name(); + } + return embedding; +} + +EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) { + EmbeddingResult embedding_result; + embedding_result.embeddings.reserve(proto.embeddings_size()); + for (const auto& embedding : proto.embeddings()) { + embedding_result.embeddings.push_back(ConvertToEmbedding(embedding)); + } + if (proto.has_timestamp_ms()) { + embedding_result.timestamp_ms = proto.timestamp_ms(); + } + return embedding_result; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/embedding_result.h b/mediapipe/tasks/cc/components/containers/embedding_result.h new file mode 100644 index 000000000..2d01d2f2a --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/embedding_result.h @@ -0,0 +1,72 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_ + +#include +#include +#include + +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe::tasks::components::containers { + +// Embedding result for a given embedder head. +// +// One and only one of the two 'float_embedding' and 'quantized_embedding' will +// contain data, based on whether or not the embedder was configured to perform +// scalar quantization. +struct Embedding { + // Floating-point embedding. Empty if the embedder was configured to perform + // scalar-quantization. + std::vector float_embedding; + // Scalar-quantized embedding. Empty if the embedder was not configured to + // perform scalar quantization. + std::string quantized_embedding; + // The index of the embedder head (i.e. output tensor) this embedding comes + // from. This is useful for multi-head models. + int head_index; + // The optional name of the embedder head, as provided in the TFLite Model + // Metadata [1] if present. 
This is useful for multi-head models. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + std::optional head_name = std::nullopt; +}; + +// Defines embedding results of a model. +struct EmbeddingResult { + // The embedding results for each head of the model. + std::vector embeddings; + // The optional timestamp (in milliseconds) of the start of the chunk of data + // corresponding to these results. + // + // This is only used for embedding extraction on time series (e.g. audio + // embedding). In these use cases, the amount of data to process might + // exceed the maximum size that the model can process: to solve this, the + // input data is split into multiple chunks starting at different timestamps. + std::optional timestamp_ms = std::nullopt; +}; + +// Utility function to convert from Embedding proto to Embedding struct. +Embedding ConvertToEmbedding(const proto::Embedding& proto); + +// Utility function to convert from EmbeddingResult proto to EmbeddingResult +// struct. +EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto); + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_ diff --git a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto index 39811e6c0..4f888c699 100644 --- a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto +++ b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto @@ -30,30 +30,31 @@ message QuantizedEmbedding { optional bytes values = 1; } -// Floating-point or scalar-quantized embedding with an optional timestamp. -message EmbeddingEntry { - // The actual embedding, either floating-point or scalar-quantized. +// Embedding result for a given embedder head. +message Embedding { + // The actual embedding, either floating-point or quantized. oneof embedding { FloatEmbedding float_embedding = 1; QuantizedEmbedding quantized_embedding = 2; } - // The optional timestamp (in milliseconds) associated to the embedding entry. - // This is useful for time series use cases, e.g. audio embedding. - optional int64 timestamp_ms = 3; -} - -// Embeddings for a given embedder head. -message Embeddings { - repeated EmbeddingEntry entries = 1; // The index of the embedder head that produced this embedding. This is useful // for multi-head models. - optional int32 head_index = 2; + optional int32 head_index = 3; // The name of the embedder head, which is the corresponding tensor metadata // name (if any). This is useful for multi-head models. - optional string head_name = 3; + optional string head_name = 4; } -// Contains one set of results per embedder head. +// Embedding results for a given embedder model. message EmbeddingResult { - repeated Embeddings embeddings = 1; + // The embedding results for each model head, i.e. one for each output tensor. + repeated Embedding embeddings = 1; + // The optional timestamp (in milliseconds) of the start of the chunk of data + // corresponding to these results. + // + // This is only used for embedding extraction on time series (e.g. audio + // embedding). In these use cases, the amount of data to process might + // exceed the maximum size that the model can process: to solve this, the + // input data is split into multiple chunks starting at different timestamps. 
+ optional int64 timestamp_ms = 2; } diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 62f04dcb7..12af55ed9 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -62,3 +62,38 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "embedder_options", + srcs = ["embedder_options.cc"], + hdrs = ["embedder_options.h"], + deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"], +) + +cc_library( + name = "embedding_postprocessing_graph", + srcs = ["embedding_postprocessing_graph.cc"], + hdrs = ["embedding_postprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:tensors_dequantization_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/tool:options_map", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/embedder_options.cc b/mediapipe/tasks/cc/components/processors/embedder_options.cc similarity index 77% rename from mediapipe/tasks/cc/components/embedder_options.cc rename to mediapipe/tasks/cc/components/processors/embedder_options.cc index 9cc399f7b..fce7baa44 100644 --- a/mediapipe/tasks/cc/components/embedder_options.cc +++ b/mediapipe/tasks/cc/components/processors/embedder_options.cc @@ -13,22 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/components/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( +proto::EmbedderOptions ConvertEmbedderOptionsToProto( EmbedderOptions* embedder_options) { - tasks::components::proto::EmbedderOptions options_proto; + proto::EmbedderOptions options_proto; options_proto.set_l2_normalize(embedder_options->l2_normalize); options_proto.set_quantize(embedder_options->quantize); return options_proto; } +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/embedder_options.h b/mediapipe/tasks/cc/components/processors/embedder_options.h similarity index 79% rename from mediapipe/tasks/cc/components/embedder_options.h rename to mediapipe/tasks/cc/components/processors/embedder_options.h index 9ed0fee87..b37171592 100644 --- a/mediapipe/tasks/cc/components/embedder_options.h +++ b/mediapipe/tasks/cc/components/processors/embedder_options.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { // Embedder options for MediaPipe C++ embedding extraction tasks. struct EmbedderOptions { @@ -37,11 +38,12 @@ struct EmbedderOptions { bool quantize; }; -tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( +proto::EmbedderOptions ConvertEmbedderOptionsToProto( EmbedderOptions* embedder_options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc similarity index 94% rename from mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc rename to mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 4ea009cb8..3a3884689 100644 --- a/mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -39,6 +39,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; -using ::mediapipe::tasks::components::proto::EmbedderOptions; using ::mediapipe::tasks::core::ModelResources; using TensorsSource = ::mediapipe::tasks::SourceOrNodeOutput>; constexpr char kTensorsTag[] = "TENSORS"; -constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; // Identifies whether or not the model has quantized outputs, and performs // sanity checks. 
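With this move, consumers address the subgraph under the processors namespace and read its result from the renamed "EMBEDDINGS" output instead of "EMBEDDING_RESULT". A minimal builder sketch, not part of this change, assuming an api2 builder Graph named `graph`, an existing `inference` node, and pre-configured `model_resources` and `embedder_options` objects (all hypothetical names):

// Sketch only: connects an inference node's output tensors to the relocated
// EmbeddingPostprocessingGraph and exposes its "EMBEDDINGS" stream.
auto& postprocessing = graph.AddNode(
    "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
    model_resources, embedder_options,
    &postprocessing.GetOptions<
        components::processors::proto::EmbeddingPostprocessingGraphOptions>()));
inference.Out("TENSORS") >> postprocessing.In("TENSORS");
postprocessing.Out("EMBEDDINGS") >> graph.Out("EMBEDDINGS");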
@@ -144,7 +144,7 @@ absl::StatusOr> GetHeadNames( absl::Status ConfigureEmbeddingPostprocessing( const ModelResources& model_resources, - const EmbedderOptions& embedder_options, + const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options) { ASSIGN_OR_RETURN(bool has_quantized_outputs, HasQuantizedOutputs(model_resources)); @@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { BuildEmbeddingPostprocessing( sc->Options(), graph[Input>(kTensorsTag)], graph)); - embedding_result_out >> graph[Output(kEmbeddingResultTag)]; + embedding_result_out >> graph[Output(kEmbeddingsTag)]; return graph.GetConfig(); } @@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { .GetOptions() .CopyFrom(options.tensors_to_embeddings_options()); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); - return tensors_to_embeddings_node[Output( - kEmbeddingResultTag)]; + return tensors_to_embeddings_node[Output(kEmbeddingsTag)]; } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::EmbeddingPostprocessingGraph); + ::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h similarity index 77% rename from mediapipe/tasks/cc/components/embedding_postprocessing_graph.h rename to mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index af8fa6706..5e8f2c084 100644 --- a/mediapipe/tasks/cc/components/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { // Configures an EmbeddingPostprocessingGraph using the provided model resources // and EmbedderOptions. @@ -44,18 +45,19 @@ namespace components { // The output tensors of an InferenceCalculator, to convert into // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // Outputs: -// EMBEDDING_RESULT - EmbeddingResult +// EMBEDDINGS - EmbeddingResult // The output EmbeddingResult. // // TODO: add support for additional optional "TIMESTAMPS" input for // embeddings aggregation. 
absl::Status ConfigureEmbeddingPostprocessing( const tasks::core::ModelResources& model_resources, - const tasks::components::proto::EmbedderOptions& embedder_options, + const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc similarity index 71% rename from mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc rename to mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 9c0d21ab2..62fab8f7e 100644 --- a/mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include @@ -25,8 +25,8 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -34,12 +34,10 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::proto::EmbedderOptions; -using ::mediapipe::tasks::components::proto:: - EmbeddingPostprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; @@ -69,68 +67,72 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - EmbedderOptions options_in; + proto::EmbedderOptions options_in; options_in.set_l2_normalize(true); - EmbeddingPostprocessingGraphOptions options_out; + proto::EmbeddingPostprocessingGraphOptions options_out; MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, &options_out)); EXPECT_THAT( options_out, - EqualsProto(ParseTextProtoOrDie( - R"pb(tensors_to_embeddings_options { - embedder_options { l2_normalize: true } - head_names: "probability" - } - has_quantized_outputs: true)pb"))); + EqualsProto( + ParseTextProtoOrDie( + R"pb(tensors_to_embeddings_options { + embedder_options { l2_normalize: true } + head_names: "probability" + } + has_quantized_outputs: true)pb"))); } TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - EmbedderOptions options_in; + proto::EmbedderOptions options_in; options_in.set_quantize(true); - EmbeddingPostprocessingGraphOptions options_out; + proto::EmbeddingPostprocessingGraphOptions options_out; MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, &options_out)); EXPECT_THAT( options_out, - EqualsProto(ParseTextProtoOrDie( - R"pb(tensors_to_embeddings_options { - embedder_options { quantize: true } - } - has_quantized_outputs: true)pb"))); + EqualsProto( + ParseTextProtoOrDie( + R"pb(tensors_to_embeddings_options { + embedder_options { quantize: true } + } + has_quantized_outputs: true)pb"))); } TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { MP_ASSERT_OK_AND_ASSIGN(auto model_resources, CreateModelResourcesForModel(kMobileNetV3Embedder)); - EmbedderOptions options_in; + proto::EmbedderOptions options_in; options_in.set_quantize(true); options_in.set_l2_normalize(true); - EmbeddingPostprocessingGraphOptions options_out; + proto::EmbeddingPostprocessingGraphOptions options_out; MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, &options_out)); EXPECT_THAT( options_out, - EqualsProto(ParseTextProtoOrDie( - R"pb(tensors_to_embeddings_options { - embedder_options { quantize: true l2_normalize: true } - head_names: "feature" - } - has_quantized_outputs: false)pb"))); + EqualsProto( + ParseTextProtoOrDie( + R"pb(tensors_to_embeddings_options { + embedder_options { quantize: true l2_normalize: true } + head_names: "feature" + } + has_quantized_outputs: false)pb"))); } // TODO: add E2E Postprocessing tests once timestamp aggregation is // supported. 
} // namespace +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index d7cbe47ff..23ebbe008 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -34,3 +34,18 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", ], ) + +mediapipe_proto_library( + name = "embedder_options_proto", + srcs = ["embedder_options.proto"], +) + +mediapipe_proto_library( + name = "embedding_postprocessing_graph_options_proto", + srcs = ["embedding_postprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/proto/embedder_options.proto b/mediapipe/tasks/cc/components/processors/proto/embedder_options.proto similarity index 88% rename from mediapipe/tasks/cc/components/proto/embedder_options.proto rename to mediapipe/tasks/cc/components/processors/proto/embedder_options.proto index 8a60a1398..8973ab248 100644 --- a/mediapipe/tasks/cc/components/proto/embedder_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/embedder_options.proto @@ -15,7 +15,10 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; + +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "EmbedderOptionsProto"; // Shared options used by all embedding extraction tasks. message EmbedderOptions { diff --git a/mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto similarity index 96% rename from mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto rename to mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto index 4e79f8178..f8dbf59f0 100644 --- a/mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto @@ -15,7 +15,7 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto"; diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index c11d6f95a..4534a1652 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -23,21 +23,6 @@ mediapipe_proto_library( srcs = ["segmenter_options.proto"], ) -mediapipe_proto_library( - name = "embedder_options_proto", - srcs = ["embedder_options.proto"], -) - -mediapipe_proto_library( - name = "embedding_postprocessing_graph_options_proto", - srcs = ["embedding_postprocessing_graph_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", - ], -) - mediapipe_proto_library( name = "text_preprocessing_graph_options_proto", srcs = ["text_preprocessing_graph_options.proto"], diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index d16e2fbc4..8bb5b8415 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -26,7 +26,7 @@ cc_library( hdrs = ["cosine_similarity.h"], deps = [ "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/containers:embedding_result", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -39,7 +39,7 @@ cc_test( deps = [ ":cosine_similarity", "//mediapipe/framework/port:gtest_main", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/containers:embedding_result", ], ) diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity.cc b/mediapipe/tasks/cc/components/utils/cosine_similarity.cc index af471a2d8..1403700c8 100644 --- a/mediapipe/tasks/cc/components/utils/cosine_similarity.cc +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" namespace mediapipe { namespace tasks { @@ -30,7 +30,7 @@ namespace utils { namespace { -using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::mediapipe::tasks::components::containers::Embedding; template absl::StatusOr ComputeCosineSimilarity(const T& u, const T& v, @@ -66,39 +66,35 @@ absl::StatusOr ComputeCosineSimilarity(const T& u, const T& v, // an L2-norm of 0. 
// // [1]: https://en.wikipedia.org/wiki/Cosine_similarity -absl::StatusOr CosineSimilarity(const EmbeddingEntry& u, - const EmbeddingEntry& v) { - if (u.has_float_embedding() && v.has_float_embedding()) { - if (u.float_embedding().values().size() != - v.float_embedding().values().size()) { +absl::StatusOr CosineSimilarity(const Embedding& u, + const Embedding& v) { + if (!u.float_embedding.empty() && !v.float_embedding.empty()) { + if (u.float_embedding.size() != v.float_embedding.size()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Cannot compute cosine similarity between embeddings " "of different sizes (%d vs. %d)", - u.float_embedding().values().size(), - v.float_embedding().values().size()), + u.float_embedding.size(), v.float_embedding.size()), MediaPipeTasksStatus::kInvalidArgumentError); } - return ComputeCosineSimilarity(u.float_embedding().values().data(), - v.float_embedding().values().data(), - u.float_embedding().values().size()); + return ComputeCosineSimilarity(u.float_embedding.data(), + v.float_embedding.data(), + u.float_embedding.size()); } - if (u.has_quantized_embedding() && v.has_quantized_embedding()) { - if (u.quantized_embedding().values().size() != - v.quantized_embedding().values().size()) { + if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) { + if (u.quantized_embedding.size() != v.quantized_embedding.size()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Cannot compute cosine similarity between embeddings " "of different sizes (%d vs. %d)", - u.quantized_embedding().values().size(), - v.quantized_embedding().values().size()), + u.quantized_embedding.size(), + v.quantized_embedding.size()), MediaPipeTasksStatus::kInvalidArgumentError); } - return ComputeCosineSimilarity(reinterpret_cast( - u.quantized_embedding().values().data()), - reinterpret_cast( - v.quantized_embedding().values().data()), - u.quantized_embedding().values().size()); + return ComputeCosineSimilarity( + reinterpret_cast(u.quantized_embedding.data()), + reinterpret_cast(v.quantized_embedding.data()), + u.quantized_embedding.size()); } return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity.h b/mediapipe/tasks/cc/components/utils/cosine_similarity.h index 4356811cd..45ddce76e 100644 --- a/mediapipe/tasks/cc/components/utils/cosine_similarity.h +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity.h @@ -17,22 +17,20 @@ limitations under the License. #define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ #include "absl/status/statusor.h" -#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" namespace mediapipe { namespace tasks { namespace components { namespace utils { -// Utility function to compute cosine similarity [1] between two embedding -// entries. May return an InvalidArgumentError if e.g. the feature vectors are -// of different types (quantized vs. float), have different sizes, or have a -// an L2-norm of 0. +// Utility function to compute cosine similarity [1] between two embeddings. May +// return an InvalidArgumentError if e.g. the embeddings are of different types +// (quantized vs. float), have different sizes, or have a an L2-norm of 0. 
// // [1]: https://en.wikipedia.org/wiki/Cosine_similarity -absl::StatusOr CosineSimilarity( - const containers::proto::EmbeddingEntry& u, - const containers::proto::EmbeddingEntry& v); +absl::StatusOr CosineSimilarity(const containers::Embedding& u, + const containers::Embedding& v); } // namespace utils } // namespace components diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc b/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc index 176f7f7a6..4ff9dfc3a 100644 --- a/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" namespace mediapipe { namespace tasks { @@ -30,29 +30,27 @@ namespace components { namespace utils { namespace { -using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::mediapipe::tasks::components::containers::Embedding; using ::testing::HasSubstr; -// Helper function to generate float EmbeddingEntry. -EmbeddingEntry BuildFloatEntry(std::vector values) { - EmbeddingEntry entry; - for (const float value : values) { - entry.mutable_float_embedding()->add_values(value); - } - return entry; +// Helper function to generate float Embedding. +Embedding BuildFloatEmbedding(std::vector values) { + Embedding embedding; + embedding.float_embedding = values; + return embedding; } -// Helper function to generate quantized EmbeddingEntry. -EmbeddingEntry BuildQuantizedEntry(std::vector values) { - EmbeddingEntry entry; - entry.mutable_quantized_embedding()->set_values( - reinterpret_cast(values.data()), values.size()); - return entry; +// Helper function to generate quantized Embedding. 
+Embedding BuildQuantizedEmbedding(std::vector values) { + Embedding embedding; + uint8_t* data = reinterpret_cast(values.data()); + embedding.quantized_embedding = {data, data + values.size()}; + return embedding; } TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) { - auto u = BuildFloatEntry({0.1, 0.2}); - auto v = BuildQuantizedEntry({0, 1}); + auto u = BuildFloatEmbedding({0.1, 0.2}); + auto v = BuildQuantizedEmbedding({0, 1}); auto status = CosineSimilarity(u, v); @@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) { } TEST(CosineSimilarity, FailsWithZeroNorm) { - auto u = BuildFloatEntry({0.1, 0.2}); - auto v = BuildFloatEntry({0.0, 0.0}); + auto u = BuildFloatEmbedding({0.1, 0.2}); + auto v = BuildFloatEmbedding({0.0, 0.0}); auto status = CosineSimilarity(u, v); @@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) { } TEST(CosineSimilarity, FailsWithDifferentSizes) { - auto u = BuildFloatEntry({0.1, 0.2}); - auto v = BuildFloatEntry({0.1, 0.2, 0.3}); + auto u = BuildFloatEmbedding({0.1, 0.2}); + auto v = BuildFloatEmbedding({0.1, 0.2, 0.3}); auto status = CosineSimilarity(u, v); @@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) { } TEST(CosineSimilarity, SucceedsWithFloatEntries) { - auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0}); - auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5}); + auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0}); + auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5}); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); @@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) { } TEST(CosineSimilarity, SucceedsWithQuantizedEntries) { - auto u = BuildQuantizedEntry({127, 0, 0, 0}); - auto v = BuildQuantizedEntry({-128, 0, 0, 0}); + auto u = BuildQuantizedEmbedding({127, 0, 0, 0}); + auto v = BuildQuantizedEmbedding({-128, 0, 0, 0}); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index 0f63f87e4..ea7f40261 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -26,12 +26,12 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:embedding_postprocessing_graph", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", @@ -49,9 +49,10 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/tool:options_map", - "//mediapipe/tasks/cc/components:embedder_options", + "//mediapipe/tasks/cc/components/containers:embedding_result", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", + 
"//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/utils:cosine_similarity", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index 1dc316305..e3198090f 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -21,9 +21,10 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/tool/options_map.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/embedder_options.h" -#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" @@ -41,8 +42,8 @@ namespace image_embedder { namespace { -constexpr char kEmbeddingResultStreamName[] = "embedding_result_out"; -constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; +constexpr char kEmbeddingsStreamName[] = "embeddings_out"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] = "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; -using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::vision::image_embedder::proto:: @@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig( graph.In(kNormRectTag).SetName(kNormRectStreamName); auto& task_graph = graph.AddNode(kGraphTypeName); task_graph.GetOptions().Swap(options_proto.get()); - task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >> - graph.Out(kEmbeddingResultTag); + task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >> + graph.Out(kEmbeddingsTag); task_graph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { return tasks::core::AddFlowLimiterCalculator( - graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag); + graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag); } graph.In(kImageTag) >> task_graph.In(kImageTag); graph.In(kNormRectTag) >> task_graph.In(kNormRectTag); @@ -95,8 +96,8 @@ std::unique_ptr ConvertImageEmbedderOptionsToProto( options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); auto embedder_options_proto = - std::make_unique( - components::ConvertEmbedderOptionsToProto( + std::make_unique( + components::processors::ConvertEmbedderOptionsToProto( &(options->embedder_options))); 
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); return options_proto; @@ -121,9 +122,10 @@ absl::StatusOr> ImageEmbedder::Create( return; } Packet embedding_result_packet = - status_or_packets.value()[kEmbeddingResultStreamName]; + status_or_packets.value()[kEmbeddingsStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(embedding_result_packet.Get(), + result_callback(ConvertToEmbeddingResult( + embedding_result_packet.Get()), image_packet.Get(), embedding_result_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -138,7 +140,7 @@ absl::StatusOr> ImageEmbedder::Create( std::move(packets_callback)); } -absl::StatusOr ImageEmbedder::Embed( +absl::StatusOr ImageEmbedder::Embed( Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -155,10 +157,11 @@ absl::StatusOr ImageEmbedder::Embed( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectStreamName, MakePacket(std::move(norm_rect))}})); - return output_packets[kEmbeddingResultStreamName].Get(); + return ConvertToEmbeddingResult( + output_packets[kEmbeddingsStreamName].Get()); } -absl::StatusOr ImageEmbedder::EmbedForVideo( +absl::StatusOr ImageEmbedder::EmbedForVideo( Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -178,7 +181,8 @@ absl::StatusOr ImageEmbedder::EmbedForVideo( {kNormRectStreamName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kEmbeddingResultStreamName].Get(); + return ConvertToEmbeddingResult( + output_packets[kEmbeddingsStreamName].Get()); } absl::Status ImageEmbedder::EmbedAsync( @@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync( } absl::StatusOr ImageEmbedder::CosineSimilarity( - const EmbeddingEntry& u, const EmbeddingEntry& v) { + const components::containers::Embedding& u, + const components::containers::Embedding& v) { return components::utils::CosineSimilarity(u, v); } diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h index 3a2a1dbee..9320cbc35 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/embedder_options.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -33,6 +33,10 @@ namespace tasks { namespace vision { namespace image_embedder { +// Alias the shared EmbeddingResult struct as result typo. +using ImageEmbedderResult = + ::mediapipe::tasks::components::containers::EmbeddingResult; + // The options for configuring a MediaPipe image embedder task. struct ImageEmbedderOptions { // Base options for configuring MediaPipe Tasks, such as specifying the model @@ -50,14 +54,12 @@ struct ImageEmbedderOptions { // Options for configuring the embedder behavior, such as L2-normalization or // scalar-quantization. 
- components::EmbedderOptions embedder_options; + components::processors::EmbedderOptions embedder_options; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function, - const Image&, int64)> + std::function, const Image&, int64)> result_callback = nullptr; }; @@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi { // running mode. // // The image can be of any size with format RGB or RGBA. - absl::StatusOr Embed( + absl::StatusOr Embed( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - absl::StatusOr EmbedForVideo( + absl::StatusOr EmbedForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi { // Shuts down the ImageEmbedder when all works are done. absl::Status Close() { return runner_->Close(); } - // Utility function to compute cosine similarity [1] between two embedding - // entries. May return an InvalidArgumentError if e.g. the feature vectors are - // of different types (quantized vs. float), have different sizes, or have a - // an L2-norm of 0. + // Utility function to compute cosine similarity [1] between two embeddings. + // May return an InvalidArgumentError if e.g. the embeddings are of different + // types (quantized vs. float), have different sizes, or have a an L2-norm of + // 0. // // [1]: https://en.wikipedia.org/wiki/Cosine_similarity static absl::StatusOr CosineSimilarity( - const components::containers::proto::EmbeddingEntry& u, - const components::containers::proto::EmbeddingEntry& v); + const components::containers::Embedding& u, + const components::containers::Embedding& v); }; } // namespace image_embedder diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index f0f440986..11e25144c 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; -using ::mediapipe::tasks::components::proto:: - EmbeddingPostprocessingGraphOptions; -constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; @@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams { // Describes region of image to perform embedding extraction on. // @Optional: rect covering the whole image is used if not specified. // Outputs: -// EMBEDDING_RESULT - EmbeddingResult +// EMBEDDINGS - EmbeddingResult // The embedding result. // IMAGE - Image // The image that embedding extraction runs on. @@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams { // node { // calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph" // input_stream: "IMAGE:image_in" -// output_stream: "EMBEDDING_RESULT:embedding_result_out" +// output_stream: "EMBEDDINGS:embedding_result_out" // output_stream: "IMAGE:image_out" // options { // [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext] @@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); output_streams.embedding_result >> - graph[Output(kEmbeddingResultTag)]; + graph[Output(kEmbeddingsTag)]; output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects its input stream to the // inference results. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing( + "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + &postprocessing.GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the embedding results. 
return ImageEmbedderOutputStreams{ /*embedding_result=*/postprocessing[Output( - kEmbeddingResultTag)], + kEmbeddingsTag)], /*image=*/preprocessing[Output(kImageTag)]}; } }; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 386b6c8eb..6098a9a70 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -42,7 +42,6 @@ namespace { using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::components::containers::Rect; -using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6; // Utility function to check the sizes, head_index and head_names of a result // procuded by kMobileNetV3Embedder. -void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) { - EXPECT_EQ(result.embeddings().size(), 1); - EXPECT_EQ(result.embeddings(0).head_index(), 0); - EXPECT_EQ(result.embeddings(0).head_name(), "feature"); - EXPECT_EQ(result.embeddings(0).entries().size(), 1); +void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) { + EXPECT_EQ(result.embeddings.size(), 1); + EXPECT_EQ(result.embeddings[0].head_index, 0); + EXPECT_EQ(result.embeddings[0].head_name, "feature"); if (quantized) { - EXPECT_EQ( - result.embeddings(0).entries(0).quantized_embedding().values().size(), - 1024); + EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024); } else { - EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(), - 1024); + EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024); } } @@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); options->running_mode = running_mode; - options->result_callback = [](absl::StatusOr, + options->result_callback = [](absl::StatusOr, const Image& image, int64 timestamp_ms) {}; auto image_embedder = ImageEmbedder::Create(std::move(options)); @@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) { JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result, image_embedder->Embed(image)); - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result, image_embedder->Embed(crop)); // Check results. CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(crop_result, false); // CheckCosineSimilarity. 
- MP_ASSERT_OK_AND_ASSIGN( - double similarity, - ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), - crop_result.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity( + image_result.embeddings[0], + crop_result.embeddings[0])); double expected_similarity = 0.925519; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) { JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result, image_embedder->Embed(image)); - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result, image_embedder->Embed(crop)); // Check results. CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(crop_result, false); // CheckCosineSimilarity. - MP_ASSERT_OK_AND_ASSIGN( - double similarity, - ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), - crop_result.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity( + image_result.embeddings[0], + crop_result.embeddings[0])); double expected_similarity = 0.925519; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) { JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result, image_embedder->Embed(image)); - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result, image_embedder->Embed(crop)); // Check results. CheckMobileNetV3Result(image_result, true); CheckMobileNetV3Result(crop_result, true); // CheckCosineSimilarity. - MP_ASSERT_OK_AND_ASSIGN( - double similarity, - ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), - crop_result.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity( + image_result.embeddings[0], + crop_result.embeddings[0])); double expected_similarity = 0.926791; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { // Extract both embeddings. MP_ASSERT_OK_AND_ASSIGN( - const EmbeddingResult& image_result, + const ImageEmbedderResult& image_result, image_embedder->Embed(image, image_processing_options)); - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result, image_embedder->Embed(crop)); // Check results. CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(crop_result, false); // CheckCosineSimilarity. 
- MP_ASSERT_OK_AND_ASSIGN( - double similarity, - ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), - crop_result.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity( + image_result.embeddings[0], + crop_result.embeddings[0])); double expected_similarity = 0.999931; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { image_processing_options.rotation_degrees = -90; // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result, image_embedder->Embed(image)); MP_ASSERT_OK_AND_ASSIGN( - const EmbeddingResult& rotated_result, + const ImageEmbedderResult& rotated_result, image_embedder->Embed(rotated, image_processing_options)); // Check results. CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(rotated_result, false); // CheckCosineSimilarity. - MP_ASSERT_OK_AND_ASSIGN( - double similarity, - ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), - rotated_result.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity( + image_result.embeddings[0], + rotated_result.embeddings[0])); double expected_similarity = 0.572265; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { /*rotation_degrees=*/-90}; // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result, image_embedder->Embed(crop)); MP_ASSERT_OK_AND_ASSIGN( - const EmbeddingResult& rotated_result, + const ImageEmbedderResult& rotated_result, image_embedder->Embed(rotated, image_processing_options)); // Check results. CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(rotated_result, false); // CheckCosineSimilarity. 
- MP_ASSERT_OK_AND_ASSIGN( - double similarity, - ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), - rotated_result.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity( + crop_result.embeddings[0], + rotated_result.embeddings[0])); double expected_similarity = 0.62838; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, ImageEmbedder::Create(std::move(options))); - EmbeddingResult previous_results; + ImageEmbedderResult previous_results; for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_embedder->EmbedForVideo(image, i)); CheckMobileNetV3Result(results, false); if (i > 0) { - MP_ASSERT_OK_AND_ASSIGN(double similarity, - ImageEmbedder::CosineSimilarity( - results.embeddings(0).entries(0), - previous_results.embeddings(0).entries(0))); + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(results.embeddings[0], + previous_results.embeddings[0])); double expected_similarity = 1.000000; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } @@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = [](absl::StatusOr, + options->result_callback = [](absl::StatusOr, const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, ImageEmbedder::Create(std::move(options))); @@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = [](absl::StatusOr, + options->result_callback = [](absl::StatusOr, const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, ImageEmbedder::Create(std::move(options))); @@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { } struct LiveStreamModeResults { - EmbeddingResult embedding_result; + ImageEmbedderResult embedding_result; std::pair image_size; int64 timestamp_ms; }; @@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); options->running_mode = core::RunningMode::LIVE_STREAM; options->result_callback = - [&results](absl::StatusOr embedding_result, + [&results](absl::StatusOr embedding_result, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(embedding_result.status()); results.push_back( @@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( double similarity, ImageEmbedder::CosineSimilarity( - result.embedding_result.embeddings(0).entries(0), - results[i - 1].embedding_result.embeddings(0).entries(0))); + result.embedding_result.embeddings[0], + results[i - 1].embedding_result.embeddings[0])); double expected_similarity = 1.000000; EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD b/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD index 83407001f..ecf8b0242 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD 
@@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:embedder_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index e5e31a729..4adba5ab7 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_embedder.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message ImageEmbedderGraphOptions { @@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions { // Options for configuring the embedder behavior, such as normalization or // quantization. - optional components.proto.EmbedderOptions embedder_options = 2; + optional components.processors.proto.EmbedderOptions embedder_options = 2; } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 0260e3fab..ff40768f7 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", From c10ebe8476ea70bc72e496aeb440e2e3fdd30ba4 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 13:21:40 -0700 Subject: [PATCH 31/34] Remove the use of designated initializers. 
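Designated initializers such as ".gesture = hand_gestures" only became standard C++ in C++20, which is presumably why this patch (and the similar hand_landmarker change two patches later) rewrites them as plain aggregate initialization with /*field=*/ argument comments. A minimal sketch of the pattern on a hypothetical struct, not code from this patch:

    struct Outputs {
      int gesture;
      int image;
    };

    // C++20-style designated initializers (the form being removed):
    //   return Outputs{.gesture = 1, .image = 2};
    // C++17-friendly positional aggregate initialization (the form being added):
    Outputs MakeOutputs() { return Outputs{/*gesture=*/1, /*image=*/2}; }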
PiperOrigin-RevId: 486215798 --- .../gesture_recognizer/gesture_recognizer_graph.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 7ab4847dd..47d95100b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { hand_gesture_subgraph[Output>( kHandGesturesTag)]; - return {{.gesture = hand_gestures, - .handedness = handedness, - .hand_landmarks = hand_landmarks, - .hand_world_landmarks = hand_world_landmarks, - .image = hand_landmarker_graph[Output(kImageTag)]}}; + return GestureRecognizerOutputs{ + /*gesture=*/hand_gestures, + /*handedness=*/handedness, + /*hand_landmarks=*/hand_landmarks, + /*hand_world_landmarks=*/hand_world_landmarks, + /*image=*/hand_landmarker_graph[Output(kImageTag)]}; } }; From 35f635d8ff9a4800ef8ef56f6526b833936ecfaf Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 16:37:14 -0700 Subject: [PATCH 32/34] Add an argument to packet_creator.create_matrix to allow the input matrix to be transposed first. PiperOrigin-RevId: 486258078 --- mediapipe/python/pybind/packet_creator.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index 421ac44d3..b36fa306a 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) { // TODO: Should take "const Eigen::Ref&" // as the input argument. Investigate why bazel non-optimized mode // triggers a memory allocation bug in Eigen::internal::aligned_free(). - [](const Eigen::MatrixXf& matrix) { + [](const Eigen::MatrixXf& matrix, bool transpose) { // MakePacket copies the data. + if (transpose) { + return MakePacket(matrix.transpose()); + } return MakePacket(matrix); }, R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray. @@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) { Args: matrix: A 2d numpy float ndarray. + transpose: A boolean to indicate if the input matrix needs to be transposed. + Default to False. Returns: A MediaPipe Matrix Packet. @@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) { np.array([[.1, .2, .3], [.4, .5, .6]]) matrix = mp.packet_getter.get_matrix(packet) )doc", + py::arg("matrix"), py::arg("transpose") = false, py::return_value_policy::move); } // NOLINT(readability/fn_size) From 416f91180b64b67520289a0fabed4173c6e32e5a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 16:55:26 -0700 Subject: [PATCH 33/34] Remove the use of designated initializers. 
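A note on the packet_creator.create_matrix change in the preceding patch: because MakePacket copies its argument, passing transpose=true simply materializes the transposed matrix at packet-creation time (from Python this surfaces as an optional keyword argument, e.g. create_matrix(data, transpose=True)). A rough C++ sketch of the equivalent behavior, not the exact binding code:

    #include "Eigen/Core"
    #include "mediapipe/framework/formats/matrix.h"
    #include "mediapipe/framework/packet.h"

    // Packs a float matrix into a MediaPipe Matrix packet, optionally
    // transposing it first. The transpose is evaluated eagerly because
    // MakePacket copies the data into the packet.
    mediapipe::Packet CreateMatrixPacket(const Eigen::MatrixXf& m, bool transpose) {
      return transpose ? mediapipe::MakePacket<mediapipe::Matrix>(m.transpose())
                       : mediapipe::MakePacket<mediapipe::Matrix>(m);
    }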
PiperOrigin-RevId: 486261384 --- .../hand_landmarks_deduplication_calculator.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 5a5baa50e..564184c64 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) { } // Populate normalized non rotated face bounding box - return {.left = bounding_box_left, - .top = bounding_box_top, - .right = bounding_box_right, - .bottom = bounding_box_bottom}; + return Rect{/*left=*/bounding_box_left, + /*top=*/bounding_box_top, + /*right=*/bounding_box_right, + /*bottom=*/bounding_box_bottom}; } // Uses IoU and distance of some corresponding hand landmarks to detect From 91782a27725d78efdaa84dc6fc1aabb06279af88 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Fri, 4 Nov 2022 19:45:46 -0700 Subject: [PATCH 34/34] Internal change PiperOrigin-RevId: 486283316 --- mediapipe/calculators/core/BUILD | 1 + .../core/flow_limiter_calculator.cc | 142 +++++++----- .../core/flow_limiter_calculator_test.cc | 212 +++++++++++++++++- 3 files changed, 290 insertions(+), 65 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 74398be42..ecd878115 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -936,6 +936,7 @@ cc_test( "//mediapipe/framework/tool:simulation_clock", "//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:sink", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/time", ], ) diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index d209b1dbb..5b08f3af5 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -18,7 +18,6 @@ #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/util/header_util.h" @@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS"; // FlowLimiterCalculator provides limited support for multiple input streams. // The first input stream is treated as the main input stream and successive // input streams are treated as auxiliary input streams. The auxiliary input -// streams are limited to timestamps passed on the main input stream. +// streams are limited to timestamps allowed by the "ALLOW" stream. // class FlowLimiterCalculator : public CalculatorBase { public: @@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase { cc->InputSidePackets().Tag(kMaxInFlightTag).Get()); } input_queues_.resize(cc->Inputs().NumEntries("")); + allowed_[Timestamp::Unset()] = true; RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); return absl::OkStatus(); } - // Returns true if an additional frame can be released for processing. - // The "ALLOW" output stream indicates this condition at each input frame. 
- bool ProcessingAllowed() { - return frames_in_flight_.size() < options_.max_in_flight(); - } - - // Outputs a packet indicating whether a frame was sent or dropped. - void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { - if (cc->Outputs().HasTag(kAllowTag)) { - cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket(allow).At(ts)); - } - } - - // Sets the timestamp bound or closes an output stream. - void SetNextTimestampBound(Timestamp bound, OutputStream* stream) { - if (bound > Timestamp::Max()) { - stream->Close(); - } else { - stream->SetNextTimestampBound(bound); - } - } - - // Returns true if a certain timestamp is being processed. - bool IsInFlight(Timestamp timestamp) { - return std::find(frames_in_flight_.begin(), frames_in_flight_.end(), - timestamp) != frames_in_flight_.end(); - } - - // Releases input packets up to the latest settled input timestamp. - void ProcessAuxiliaryInputs(CalculatorContext* cc) { - Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound(); - for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) { - // Release settled frames from each input queue. - while (!input_queues_[i].empty() && - input_queues_[i].front().Timestamp() < settled_bound) { - Packet packet = input_queues_[i].front(); - input_queues_[i].pop_front(); - if (IsInFlight(packet.Timestamp())) { - cc->Outputs().Get("", i).AddPacket(packet); - } - } - - // Propagate each input timestamp bound. - if (!input_queues_[i].empty()) { - Timestamp bound = input_queues_[i].front().Timestamp(); - SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); - } else { - Timestamp bound = - cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream(); - SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); - } - } - } - // Releases input packets allowed by the max_in_flight constraint. absl::Status Process(CalculatorContext* cc) final { options_ = tool::RetrieveOptions(options_, cc->Inputs()); @@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase { } ProcessAuxiliaryInputs(cc); + + // Discard old ALLOW ranges. + Timestamp input_bound = InputTimestampBound(cc); + auto first_range = std::prev(allowed_.upper_bound(input_bound)); + allowed_.erase(allowed_.begin(), first_range); return absl::OkStatus(); } + int LedgerSize() { + int result = frames_in_flight_.size() + allowed_.size(); + for (const auto& queue : input_queues_) { + result += queue.size(); + } + return result; + } + + private: + // Returns true if an additional frame can be released for processing. + // The "ALLOW" output stream indicates this condition at each input frame. + bool ProcessingAllowed() { + return frames_in_flight_.size() < options_.max_in_flight(); + } + + // Outputs a packet indicating whether a frame was sent or dropped. + void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { + if (cc->Outputs().HasTag(kAllowTag)) { + cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket(allow).At(ts)); + } + allowed_[ts] = allow; + } + + // Returns true if a timestamp falls within a range of allowed timestamps. + bool IsAllowed(Timestamp timestamp) { + auto it = allowed_.upper_bound(timestamp); + return std::prev(it)->second; + } + + // Sets the timestamp bound or closes an output stream. + void SetNextTimestampBound(Timestamp bound, OutputStream* stream) { + if (bound > Timestamp::Max()) { + stream->Close(); + } else { + stream->SetNextTimestampBound(bound); + } + } + + // Returns the lowest unprocessed input Timestamp. 
+ Timestamp InputTimestampBound(CalculatorContext* cc) { + Timestamp result = Timestamp::Done(); + for (int i = 0; i < input_queues_.size(); ++i) { + auto& queue = input_queues_[i]; + auto& stream = cc->Inputs().Get("", i); + Timestamp bound = queue.empty() + ? stream.Value().Timestamp().NextAllowedInStream() + : queue.front().Timestamp(); + result = std::min(result, bound); + } + return result; + } + + // Releases input packets up to the latest settled input timestamp. + void ProcessAuxiliaryInputs(CalculatorContext* cc) { + Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound(); + for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) { + // Release settled frames from each input queue. + while (!input_queues_[i].empty() && + input_queues_[i].front().Timestamp() < settled_bound) { + Packet packet = input_queues_[i].front(); + input_queues_[i].pop_front(); + if (IsAllowed(packet.Timestamp())) { + cc->Outputs().Get("", i).AddPacket(packet); + } + } + + // Propagate each input timestamp bound. + if (!input_queues_[i].empty()) { + Timestamp bound = input_queues_[i].front().Timestamp(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); + } else { + Timestamp bound = + cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); + } + } + } + private: FlowLimiterCalculatorOptions options_; std::vector> input_queues_; std::deque frames_in_flight_; + std::map allowed_; }; REGISTER_CALCULATOR(FlowLimiterCalculator); diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 962b1c81a..8a8cc9656 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "absl/time/clock.h" @@ -32,6 +33,7 @@ #include "mediapipe/framework/tool/simulation_clock.h" #include "mediapipe/framework/tool/simulation_clock_executor.h" #include "mediapipe/framework/tool/sink.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { @@ -77,6 +79,77 @@ std::vector PacketValues(const std::vector& packets) { return result; } +template +std::vector MakePackets(std::vector> contents) { + std::vector result; + for (auto& entry : contents) { + result.push_back(MakePacket(entry.second).At(entry.first)); + } + return result; +} + +std::string SourceString(Timestamp t) { + return (t.IsSpecialValue()) + ? 
t.DebugString() + : absl::StrCat("Timestamp(", t.DebugString(), ")"); +} + +template +class PacketsEqMatcher + : public ::testing::MatcherInterface { + public: + PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} + void DescribeTo(::std::ostream* os) const override { + *os << "The expected packet contents: \n"; + Print(packets_, os); + } + bool MatchAndExplain( + const PacketContainer& value, + ::testing::MatchResultListener* listener) const override { + if (!Equals(packets_, value)) { + if (listener->IsInterested()) { + *listener << "The actual packet contents: \n"; + Print(value, listener->stream()); + } + return false; + } + return true; + } + + private: + bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { + if (c1.size() != c2.size()) { + return false; + } + for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { + Packet p1 = *i1, p2 = *i2; + if (p1.Timestamp() != p2.Timestamp() || + p1.Get() != p2.Get()) { + return false; + } + } + return true; + } + void Print(const PacketContainer& packets, ::std::ostream* os) const { + for (auto it = packets.begin(); it != packets.end(); ++it) { + const Packet& packet = *it; + *os << (it == packets.begin() ? "{" : "") << "{" + << SourceString(packet.Timestamp()) << ", " + << packet.Get() << "}" + << (std::next(it) == packets.end() ? "}" : ", "); + } + } + + const PacketContainer packets_; +}; + +template +::testing::Matcher PackestEq( + const PacketContainer& packets) { + return MakeMatcher( + new PacketsEqMatcher(packets)); +} + // A Calculator::Process callback function. typedef std::function @@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { input_packets_[17], input_packets_[19], input_packets_[20], }; EXPECT_EQ(out_1_packets_, expected_output); - // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. + // The timestamps released by FlowLimiterCalculator for in_1_sampled, + // plus input_packets_[21]. std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[14], input_packets_[17], input_packets_[19], - input_packets_[20], + input_packets_[20], input_packets_[21], }; EXPECT_EQ(out_2_packets, expected_output_2); } @@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { + auto BoolPackestEq = PackestEq, bool>; + auto IntPackestEq = PackestEq, int>; + // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -699,11 +776,10 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { } )pb"); - auto limiter_options = ParseTextProtoOrDie(R"pb( - max_in_flight: 1 - max_in_queue: 0 - in_flight_timeout: 100000 # 100 ms - )pb"); + auto limiter_options = ParseTextProtoOrDie( + R"pb( + max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000 # 100 ms + )pb"); std::map side_packets = { {"limiter_options", MakePacket(limiter_options)}, @@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_EQ(out_1_packets_, expected_output); + EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output)); // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. 
std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_EQ(out_2_packets, expected_output_2); + EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2)); + + // Validate the ALLOW stream output. + std::vector expected_allow = MakePackets( // + {{Timestamp(0), true}, {Timestamp(10000), false}, + {Timestamp(20000), true}, {Timestamp(30000), false}, + {Timestamp(40000), true}, {Timestamp(50000), false}, + {Timestamp(60000), false}, {Timestamp(70000), false}, + {Timestamp(80000), false}, {Timestamp(90000), false}, + {Timestamp(100000), false}, {Timestamp(110000), false}, + {Timestamp(120000), false}, {Timestamp(130000), false}, + {Timestamp(140000), false}, {Timestamp(150000), true}, + {Timestamp(160000), false}, {Timestamp(170000), true}, + {Timestamp(180000), false}, {Timestamp(190000), true}, + {Timestamp(200000), false}}); + EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow)); +} + +// Shows how FlowLimiterCalculator releases auxiliary input packets. +// In this test, auxiliary input packets arrive at twice the primary rate. +TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { + auto BoolPackestEq = PackestEq, bool>; + auto IntPackestEq = PackestEq, int>; + + // Configure the test. + SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"pb( + input_stream: 'in_1' + input_stream: 'in_2' + node { + calculator: 'FlowLimiterCalculator' + input_side_packet: 'OPTIONS:limiter_options' + input_stream: 'in_1' + input_stream: 'in_2' + input_stream: 'FINISHED:out_1' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_1_sampled' + output_stream: 'in_2_sampled' + output_stream: 'ALLOW:allow' + } + node { + calculator: 'SleepCalculator' + input_side_packet: 'WARMUP_TIME:warmup_time' + input_side_packet: 'SLEEP_TIME:sleep_time' + input_side_packet: 'CLOCK:clock' + input_stream: 'PACKET:in_1_sampled' + output_stream: 'PACKET:out_1' + } + )pb"); + + auto limiter_options = ParseTextProtoOrDie( + R"pb( + max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000 # 1s + )pb"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(22000)}, + {"sleep_time", MakePacket(22000)}, + {"clock", MakePacket(clock_)}, + }; + + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + std::vector out_2_packets; + MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) { + out_2_packets.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); + + // Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2. + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 10; ++i) { + if (i % 2 == 0) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + } + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + } + + // Finish the graph run. 
+ MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); + MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // Input packets 4 and 8 are dropped due to max_in_flight. + std::vector expected_output = { + input_packets_[2], + input_packets_[6], + }; + EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output)); + // Packets following input packets 2 and 6, and not input packets 4 and 8. + std::vector expected_output_2 = { + input_packets_[1], input_packets_[2], input_packets_[3], + input_packets_[6], input_packets_[7], + }; + EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2)); + + // Validate the ALLOW stream output. + std::vector expected_allow = + MakePackets({{Timestamp(20000), 1}, + {Timestamp(40000), 0}, + {Timestamp(60000), 1}, + {Timestamp(80000), 0}}); + EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow)); } } // anonymous namespace
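The FlowLimiterCalculator change in the final patch keeps every ALLOW decision in a std::map keyed by timestamp and seeded with Timestamp::Unset(), then answers IsAllowed() by finding the greatest key not exceeding the query, so each decision governs the interval up to the next one. A small self-contained sketch of that lookup idiom, simplified to int keys rather than the calculator's Timestamp type:

    #include <iterator>
    #include <map>

    // Returns the decision recorded at the greatest key <= t, mirroring
    // FlowLimiterCalculator::IsAllowed(): upper_bound() finds the first key
    // strictly greater than t, and std::prev() steps back to the governing
    // entry. Seeding the map with a minimal key keeps std::prev() valid.
    bool IsAllowedAt(const std::map<int, bool>& allowed, int t) {
      return std::prev(allowed.upper_bound(t))->second;
    }

    // Example: with {{-1, true}, {10, false}, {30, true}},
    // IsAllowedAt(allowed, 5) is true and IsAllowedAt(allowed, 20) is false.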
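Further back in this series, the image_embedder patches replace the proto-based EmbeddingResult API with the plain ImageEmbedderResult/Embedding structs and rename the graph output tag to EMBEDDINGS. A minimal usage sketch of the refactored C++ API, with a placeholder model path, error handling collapsed to .value(), and a hypothetical helper name:

    #include <memory>
    #include <utility>

    #include "mediapipe/framework/formats/image.h"
    #include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h"

    namespace ie = ::mediapipe::tasks::vision::image_embedder;

    // Embeds two images and returns their cosine similarity using the new
    // struct-based result type (embeddings[0].float_embedding is a plain
    // std::vector<float> instead of a proto field).
    double SimilarityForImages(mediapipe::Image a, mediapipe::Image b) {
      auto options = std::make_unique<ie::ImageEmbedderOptions>();
      options->base_options.model_asset_path = "mobilenet_v3_embedder.tflite";  // placeholder
      std::unique_ptr<ie::ImageEmbedder> embedder =
          ie::ImageEmbedder::Create(std::move(options)).value();
      ie::ImageEmbedderResult result_a = embedder->Embed(a).value();
      ie::ImageEmbedderResult result_b = embedder->Embed(b).value();
      return ie::ImageEmbedder::CosineSimilarity(result_a.embeddings[0],
                                                 result_b.embeddings[0])
          .value();
    }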