From 6e9a070dd1046398a16ddedf9f8e473bede0bb21 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 3 Nov 2022 02:29:40 -0700 Subject: [PATCH 1/5] Added text classifier implementation and tests --- mediapipe/python/BUILD | 1 + mediapipe/tasks/python/test/text/BUILD | 37 +++ mediapipe/tasks/python/test/text/__init__.py | 13 + .../python/test/text/text_classifier_test.py | 246 ++++++++++++++++++ mediapipe/tasks/python/text/BUILD | 39 +++ mediapipe/tasks/python/text/__init__.py | 31 +++ mediapipe/tasks/python/text/core/BUILD | 31 +++ mediapipe/tasks/python/text/core/__init__.py | 14 + .../python/text/core/base_text_task_api.py | 57 ++++ .../tasks/python/text/text_classifier.py | 146 +++++++++++ 10 files changed, 615 insertions(+) create mode 100644 mediapipe/tasks/python/test/text/BUILD create mode 100644 mediapipe/tasks/python/test/text/__init__.py create mode 100644 mediapipe/tasks/python/test/text/text_classifier_test.py create mode 100644 mediapipe/tasks/python/text/BUILD create mode 100644 mediapipe/tasks/python/text/__init__.py create mode 100644 mediapipe/tasks/python/text/core/BUILD create mode 100644 mediapipe/tasks/python/text/core/__init__.py create mode 100644 mediapipe/tasks/python/text/core/base_text_task_api.py create mode 100644 mediapipe/tasks/python/text/text_classifier.py diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index b5e87b490..f9dcb50b4 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -90,6 +90,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", ], ) diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD new file mode 100644 index 000000000..bbb39ba80 --- /dev/null +++ b/mediapipe/tasks/python/test/text/BUILD @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.py"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/text:text_classifier", + ], +) diff --git a/mediapipe/tasks/python/test/text/__init__.py b/mediapipe/tasks/python/test/text/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/test/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py new file mode 100644 index 000000000..c97299270 --- /dev/null +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -0,0 +1,246 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for text classifier.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + + +from mediapipe.tasks.python.components.containers import category +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.text import text_classifier + +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_Category = category.Category +_ClassificationEntry = classifications_module.ClassificationEntry +_Classifications = classifications_module.Classifications +_ClassificationResult = classifications_module.ClassificationResult +_TextClassifier = text_classifier.TextClassifier +_TextClassifierOptions = text_classifier.TextClassifierOptions + +_BERT_MODEL_FILE = 'bert_text_classifier.tflite' +_REGEX_MODEL_FILE = 'test_model_text_classifier_with_regex_tokenizer.tflite' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' + +_NEGATIVE_TEXT = "What a waste of my time." +_POSITIVE_TEXT = ("This is the best movie I’ve seen in recent years." + "Strongly recommend it!") + +_BERT_NEGATIVE_RESULTS = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=0, score=0.999479, display_name='', + category_name='negative'), + _Category( + index=1, score=0.00052154, display_name='', + category_name='positive') + ], + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_BERT_POSITIVE_RESULTS = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=1, score=0.999466, display_name='', + category_name='positive'), + _Category( + index=0, score=0.000533596, display_name='', + category_name='negative') + ], + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_REGEX_NEGATIVE_RESULTS = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=0, score=0.81313, display_name='', + category_name='Negative'), + _Category( + index=1, score=0.1868704, display_name='', + category_name='Positive') + ], + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_REGEX_POSITIVE_RESULTS = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=1, score=0.5134273, display_name='', + category_name='Positive'), + _Category( + index=0, score=0.486573, display_name='', + category_name='Negative') + ], + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class ImageClassifierTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _TextClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _TextClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _TextClassifierOptions(base_options=base_options) + with _TextClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _TextClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _TextClassifierOptions(base_options=base_options) + _TextClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _TextClassifierOptions(base_options=base_options) + classifier = _TextClassifier.create_from_options(options) + self.assertIsInstance(classifier, _TextClassifier) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT, + _BERT_NEGATIVE_RESULTS), + (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _NEGATIVE_TEXT, + _BERT_NEGATIVE_RESULTS), + (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _POSITIVE_TEXT, + _BERT_POSITIVE_RESULTS), + (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _POSITIVE_TEXT, + _BERT_POSITIVE_RESULTS), + (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _NEGATIVE_TEXT, + _REGEX_NEGATIVE_RESULTS), + (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, _NEGATIVE_TEXT, + _REGEX_NEGATIVE_RESULTS), + (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _POSITIVE_TEXT, + _REGEX_POSITIVE_RESULTS), + (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, _POSITIVE_TEXT, + _REGEX_POSITIVE_RESULTS)) + def test_classify(self, model_file_type, model_name, text, + expected_classification_result): + # Creates classifier. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, model_name)) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + custom_classifier_options = _ClassifierOptions() + options = _TextClassifierOptions( + base_options=base_options, classifier_options=custom_classifier_options) + classifier = _TextClassifier.create_from_options(options) + + # Performs text classification on the input. + text_result = classifier.classify(text) + # Comparing results. + test_utils.assert_proto_equals(self, text_result.to_pb2(), + expected_classification_result.to_pb2()) + # Closes the classifier explicitly when the classifier is not used in + # a context. + classifier.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT, + _BERT_NEGATIVE_RESULTS), + (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _NEGATIVE_TEXT, + _BERT_NEGATIVE_RESULTS)) + def test_classify_in_context(self, model_file_type, model_name, text, + expected_classification_result): + # Creates classifier. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, model_name)) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + custom_classifier_options = _ClassifierOptions() + options = _TextClassifierOptions( + base_options=base_options, classifier_options=custom_classifier_options) + + with _TextClassifier.create_from_options(options) as classifier: + # Performs text classification on the input. + text_result = classifier.classify(text) + # Comparing results. + test_utils.assert_proto_equals(self, text_result.to_pb2(), + expected_classification_result.to_pb2()) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD new file mode 100644 index 000000000..0975d38dc --- /dev/null +++ b/mediapipe/tasks/python/text/BUILD @@ -0,0 +1,39 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "text_classifier", + srcs = [ + "text_classifier.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/text/core:base_text_task_api", + ], +) diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py new file mode 100644 index 000000000..def113178 --- /dev/null +++ b/mediapipe/tasks/python/text/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MediaPipe Tasks Vision API.""" + +import mediapipe.tasks.python.vision.core +import mediapipe.tasks.python.vision.image_classifier +import mediapipe.tasks.python.vision.object_detector + +ImageClassifier = image_classifier.ImageClassifier +ImageClassifierOptions = image_classifier.ImageClassifierOptions +ObjectDetector = object_detector.ObjectDetector +ObjectDetectorOptions = object_detector.ObjectDetectorOptions +RunningMode = core.vision_task_running_mode.VisionTaskRunningMode + +# Remove unnecessary modules to avoid duplication in API docs. +del core +del image_classifier +del object_detector +del mediapipe diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD new file mode 100644 index 000000000..072a0c7d8 --- /dev/null +++ b/mediapipe/tasks/python/text/core/BUILD @@ -0,0 +1,31 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "base_text_task_api", + srcs = [ + "base_text_task_api.py", + ], + deps = [ + "//mediapipe/framework:calculator_py_pb2", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/text/core/__init__.py b/mediapipe/tasks/python/text/core/__init__.py new file mode 100644 index 000000000..ad7f0fd95 --- /dev/null +++ b/mediapipe/tasks/python/text/core/__init__.py @@ -0,0 +1,14 @@ +"""Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/mediapipe/tasks/python/text/core/base_text_task_api.py b/mediapipe/tasks/python/text/core/base_text_task_api.py new file mode 100644 index 000000000..28a08199f --- /dev/null +++ b/mediapipe/tasks/python/text/core/base_text_task_api.py @@ -0,0 +1,57 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe text task base api.""" + +from mediapipe.framework import calculator_pb2 +from mediapipe.python._framework_bindings import task_runner +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_TaskRunner = task_runner.TaskRunner + + +class BaseTextTaskApi(object): + """The base class of the user-facing mediapipe text task api classes.""" + + def __init__( + self, + graph_config: calculator_pb2.CalculatorGraphConfig + ) -> None: + """Initializes the `BaseVisionTaskApi` object. + + Args: + graph_config: The mediapipe text task graph config proto. + """ + self._runner = _TaskRunner.create(graph_config) + + def close(self) -> None: + """Shuts down the mediapipe text task instance. + + Raises: + RuntimeError: If the mediapipe text task failed to close. + """ + self._runner.close() + + @doc_controls.do_not_generate_docs + def __enter__(self): + """Return `self` upon entering the runtime context.""" + return self + + @doc_controls.do_not_generate_docs + def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): + """Shuts down the mediapipe text task instance on exit of the context manager. + + Raises: + RuntimeError: If the mediapipe text task failed to close. + """ + self.close() diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py new file mode 100644 index 000000000..ba459a629 --- /dev/null +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -0,0 +1,146 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe text classifier task.""" + +import dataclasses + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +# TODO: Import MPImage directly one we have an alias +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.containers import classifications +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.text.core import base_text_task_api + +_BaseOptions = base_options_module.BaseOptions +_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_TaskInfo = task_info_module.TaskInfo + +_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' +_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_TEXT_IN_STREAM_NAME = 'text_in' +_TEXT_TAG = 'TEXT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class TextClassifierOptions: + """Options for the text classifier task. + + Attributes: + base_options: Base options for the text classifier task. + classifier_options: Options for the text classification task. + """ + base_options: _BaseOptions + classifier_options: _ClassifierOptions = _ClassifierOptions() + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _TextClassifierGraphOptionsProto: + """Generates an TextClassifierOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + classifier_options_proto = self.classifier_options.to_pb2() + + return _TextClassifierGraphOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto) + + +class TextClassifier(base_text_task_api.BaseTextTaskApi): + """Class that performs classification on text.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'TextClassifier': + """Creates an `TextClassifier` object from a TensorFlow Lite model and the default `TextClassifierOptions`. + + Args: + model_path: Path to the model. + + Returns: + `TextClassifier` object that's created from the model file and the + default `TextClassifierOptions`. + + Raises: + ValueError: If failed to create `TextClassifier` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = TextClassifierOptions(base_options=base_options) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: TextClassifierOptions) -> 'TextClassifier': + """Creates the `TextClassifier` object from text classifier options. + + Args: + options: Options for the text classifier task. + + Returns: + `TextClassifier` object that's created from `options`. + + Raises: + ValueError: If failed to create `TextClassifier` object from + `TextClassifierOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME]) + ], + output_streams=[ + ':'.join([ + _CLASSIFICATION_RESULT_TAG, + _CLASSIFICATION_RESULT_OUT_STREAM_NAME + ]) + ], + task_options=options) + return cls(task_info.generate_graph_config()) + + def classify( + self, + text: str, + ) -> classifications.ClassificationResult: + """Performs classification on the input `text`. + + Args: + text: The input text. + + Returns: + A classification result object that contains a list of classifications. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If text classification failed to run. + """ + output_packets = self._runner.process({ + _TEXT_IN_STREAM_NAME: packet_creator.create_string(text) + }) + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + + return classifications.ClassificationResult([ + classifications.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) From 2001e2b77c57e988e57e860f01dd2871569ccf82 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 3 Nov 2022 02:32:55 -0700 Subject: [PATCH 2/5] Removed unused file --- mediapipe/tasks/python/text/__init__.py | 31 ------------------------- 1 file changed, 31 deletions(-) delete mode 100644 mediapipe/tasks/python/text/__init__.py diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py deleted file mode 100644 index def113178..000000000 --- a/mediapipe/tasks/python/text/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""MediaPipe Tasks Vision API.""" - -import mediapipe.tasks.python.vision.core -import mediapipe.tasks.python.vision.image_classifier -import mediapipe.tasks.python.vision.object_detector - -ImageClassifier = image_classifier.ImageClassifier -ImageClassifierOptions = image_classifier.ImageClassifierOptions -ObjectDetector = object_detector.ObjectDetector -ObjectDetectorOptions = object_detector.ObjectDetectorOptions -RunningMode = core.vision_task_running_mode.VisionTaskRunningMode - -# Remove unnecessary modules to avoid duplication in API docs. -del core -del image_classifier -del object_detector -del mediapipe From de619e2702d4fd4c83503214a6d72d7351fd6eca Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 3 Nov 2022 02:33:50 -0700 Subject: [PATCH 3/5] Removed a comment --- mediapipe/tasks/python/text/text_classifier.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index ba459a629..a6fa852ca 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -17,7 +17,6 @@ import dataclasses from mediapipe.python import packet_creator from mediapipe.python import packet_getter -# TODO: Import MPImage directly one we have an alias from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classifications From 14e18a03d6c0bc589554fb395c583e754106379d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 3 Nov 2022 02:54:57 -0700 Subject: [PATCH 4/5] Removed unused custom classifier options in tests --- mediapipe/tasks/python/test/text/text_classifier_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index c97299270..bb4528c75 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -196,9 +196,7 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions() - options = _TextClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _TextClassifierOptions(base_options=base_options) classifier = _TextClassifier.create_from_options(options) # Performs text classification on the input. @@ -230,9 +228,7 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions() - options = _TextClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _TextClassifierOptions(base_options=base_options) with _TextClassifier.create_from_options(options) as classifier: # Performs text classification on the input. From b44c49250b7afe7d082a795f490bcd974e58616d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 4 Nov 2022 22:36:13 -0700 Subject: [PATCH 5/5] Removed unused constant --- mediapipe/tasks/python/text/text_classifier.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index a6fa852ca..12f4e413b 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -36,7 +36,6 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _TEXT_IN_STREAM_NAME = 'text_in' _TEXT_TAG = 'TEXT' _TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' -_MICRO_SECONDS_PER_MILLISECOND = 1000 @dataclasses.dataclass