From 2a2a55d1b84c908120fe237b9e892c708d05c532 Mon Sep 17 00:00:00 2001
From: kinaryml
Date: Fri, 21 Apr 2023 11:46:21 -0700
Subject: [PATCH] Added Language Detector Python API and fixed a typo in Interactive Segmenter Options' docstring

---
 mediapipe/tasks/python/test/text/BUILD        |  15 ++
 .../test/text/language_detector_test.py       | 225 ++++++++++++++++++
 mediapipe/tasks/python/text/BUILD             |  20 ++
 .../tasks/python/text/language_detector.py    | 205 ++++++++++++++++
 .../python/vision/interactive_segmenter.py    |   2 +-
 mediapipe/tasks/testdata/text/BUILD           |   6 +
 6 files changed, 472 insertions(+), 1 deletion(-)
 create mode 100644 mediapipe/tasks/python/test/text/language_detector_test.py
 create mode 100644 mediapipe/tasks/python/text/language_detector.py

diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD
index 5f2d18bc5..5f8551636 100644
--- a/mediapipe/tasks/python/test/text/BUILD
+++ b/mediapipe/tasks/python/test/text/BUILD
@@ -49,3 +49,18 @@ py_test(
         "//mediapipe/tasks/python/text:text_embedder",
     ],
 )
+
+py_test(
+    name = "language_detector_test",
+    srcs = ["language_detector_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/text:language_detector_model",
+    ],
+    deps = [
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/text:language_detector",
+    ],
+)
diff --git a/mediapipe/tasks/python/test/text/language_detector_test.py b/mediapipe/tasks/python/test/text/language_detector_test.py
new file mode 100644
index 000000000..2443d4312
--- /dev/null
+++ b/mediapipe/tasks/python/test/text/language_detector_test.py
@@ -0,0 +1,225 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for language detector.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.python.components.containers import category +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.text import language_detector + +LanguageDetectorResult = language_detector.LanguageDetectorResult +LanguageDetectorPrediction = language_detector.LanguageDetectorResult.LanguageDetectorPrediction +_BaseOptions = base_options_module.BaseOptions +_Category = category.Category +_Classifications = classification_result_module.Classifications +_LanguageDetector = language_detector.LanguageDetector +_LanguageDetectorOptions = language_detector.LanguageDetectorOptions + +_LANGUAGE_DETECTOR_MODEL = 'language_detector.tflite' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' + +_SCORE_THRESHOLD = 0.3 +_EN_TEXT = "To be, or not to be, that is the question" +_EN_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("en", 0.999856) + ] +) +_FR_TEXT = ( + "Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent." +) +_FR_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("fr", 0.999781) + ] +) +_RU_TEXT = "это какой-то английский язык" +_RU_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("ru", 0.993362) + ] +) +_MIXED_TEXT = "分久必合合久必分" +_MIXED_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("zh", 0.505424), + LanguageDetectorPrediction("ja", 0.481617) + ] +) +_TOLERANCE = 1e-6 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class LanguageDetectorTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL)) + + def _expect_language_detector_result_correct( + self, + actual_result: LanguageDetectorResult, + expect_result: LanguageDetectorResult + ): + for i, prediction in enumerate(actual_result.languages_and_scores): + expected_prediction = expect_result.languages_and_scores[i] + self.assertEqual( + prediction.language_code, expected_prediction.language_code, + ) + self.assertAlmostEqual( + prediction.probability, expected_prediction.probability, + delta=_TOLERANCE + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _LanguageDetector.create_from_model_path(self.model_path) as detector: + self.assertIsInstance(detector, _LanguageDetector) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. 
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _LanguageDetectorOptions(base_options=base_options)
+    with _LanguageDetector.create_from_options(options) as detector:
+      self.assertIsInstance(detector, _LanguageDetector)
+
+  def test_create_from_options_fails_with_invalid_model_path(self):
+    with self.assertRaisesRegex(
+        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
+      base_options = _BaseOptions(
+          model_asset_path='/path/to/invalid/model.tflite')
+      options = _LanguageDetectorOptions(base_options=base_options)
+      _LanguageDetector.create_from_options(options)
+
+  def test_create_from_options_succeeds_with_valid_model_content(self):
+    # Creates with options containing model content successfully.
+    with open(self.model_path, 'rb') as f:
+      base_options = _BaseOptions(model_asset_buffer=f.read())
+      options = _LanguageDetectorOptions(base_options=base_options)
+      detector = _LanguageDetector.create_from_options(options)
+      self.assertIsInstance(detector, _LanguageDetector)
+
+  @parameterized.parameters(
+      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
+      (ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT),
+      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
+      (ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT),
+      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
+      (ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT),
+      (ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
+      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT)
+  )
+  def test_detect(self, model_file_type, text, expected_result):
+    # Creates detector.
+    if model_file_type is ModelFileType.FILE_NAME:
+      base_options = _BaseOptions(model_asset_path=self.model_path)
+    elif model_file_type is ModelFileType.FILE_CONTENT:
+      with open(self.model_path, 'rb') as f:
+        model_content = f.read()
+      base_options = _BaseOptions(model_asset_buffer=model_content)
+    else:
+      # Should never happen
+      raise ValueError('model_file_type is invalid.')
+
+    options = _LanguageDetectorOptions(
+        base_options=base_options, score_threshold=_SCORE_THRESHOLD
+    )
+    detector = _LanguageDetector.create_from_options(options)
+
+    # Performs language detection on the input.
+    text_result = detector.detect(text)
+    # Comparing results.
+    self._expect_language_detector_result_correct(text_result, expected_result)
+    # Closes the detector explicitly when the detector is not used in
+    # a context.
+    detector.close()
+
+  @parameterized.parameters(
+      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
+      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
+      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
+      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT)
+  )
+  def test_detect_in_context(self, model_file_type, text, expected_result):
+    # Creates detector.
+    if model_file_type is ModelFileType.FILE_NAME:
+      base_options = _BaseOptions(model_asset_path=self.model_path)
+    elif model_file_type is ModelFileType.FILE_CONTENT:
+      with open(self.model_path, 'rb') as f:
+        model_content = f.read()
+      base_options = _BaseOptions(model_asset_buffer=model_content)
+    else:
+      # Should never happen
+      raise ValueError('model_file_type is invalid.')
+
+    options = _LanguageDetectorOptions(
+        base_options=base_options, score_threshold=_SCORE_THRESHOLD
+    )
+    with _LanguageDetector.create_from_options(options) as detector:
+      # Performs language detection on the input.
+      text_result = detector.detect(text)
+      # Comparing results.
+      self._expect_language_detector_result_correct(text_result, expected_result)
+
+  def test_allowlist_option(self):
+    # Creates detector.
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _LanguageDetectorOptions(
+        base_options=base_options, score_threshold=_SCORE_THRESHOLD,
+        category_allowlist=["ja"]
+    )
+    with _LanguageDetector.create_from_options(options) as detector:
+      # Performs language detection on the input.
+      text_result = detector.detect(_MIXED_TEXT)
+      # Comparing results.
+      expected_result = LanguageDetectorResult(
+          [
+              LanguageDetectorPrediction("ja", 0.481617)
+          ]
+      )
+      self._expect_language_detector_result_correct(text_result, expected_result)
+
+  def test_denylist_option(self):
+    # Creates detector.
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _LanguageDetectorOptions(
+        base_options=base_options, score_threshold=_SCORE_THRESHOLD,
+        category_denylist=["ja"]
+    )
+    with _LanguageDetector.create_from_options(options) as detector:
+      # Performs language detection on the input.
+      text_result = detector.detect(_MIXED_TEXT)
+      # Comparing results.
+      expected_result = LanguageDetectorResult(
+          [
+              LanguageDetectorPrediction("zh", 0.505424)
+          ]
+      )
+      self._expect_language_detector_result_correct(text_result, expected_result)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD
index 9d5d23261..b1dd3feb9 100644
--- a/mediapipe/tasks/python/text/BUILD
+++ b/mediapipe/tasks/python/text/BUILD
@@ -57,3 +57,23 @@ py_library(
         "//mediapipe/tasks/python/text/core:base_text_task_api",
     ],
 )
+
+py_library(
+    name = "language_detector",
+    srcs = [
+        "language_detector.py",
+    ],
+    visibility = ["//mediapipe/tasks:users"],
+    deps = [
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
+        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+        "//mediapipe/tasks/python/text/core:base_text_task_api",
+    ],
+)
diff --git a/mediapipe/tasks/python/text/language_detector.py b/mediapipe/tasks/python/text/language_detector.py
new file mode 100644
index 000000000..8d933dd59
--- /dev/null
+++ b/mediapipe/tasks/python/text/language_detector.py
@@ -0,0 +1,205 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe language detector task.""" + +import dataclasses +from typing import Optional, List + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 +from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.text.core import base_text_task_api + +_ClassificationResult = classification_result_module.ClassificationResult +_BaseOptions = base_options_module.BaseOptions +_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions +_TaskInfo = task_info_module.TaskInfo + +_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' +_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS' +_TEXT_IN_STREAM_NAME = 'text_in' +_TEXT_TAG = 'TEXT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' + + +@dataclasses.dataclass +class LanguageDetectorResult: + + @dataclasses.dataclass + class LanguageDetectorPrediction: + """A language code and its probability.""" + language_code: str + probability: float + + languages_and_scores: List[LanguageDetectorPrediction] + + +def _extract_language_detector_result( + classification_result: classification_result_module.ClassificationResult +) -> LanguageDetectorResult: + if len(classification_result.classifications) != 1: + raise ValueError( + "The LanguageDetector TextClassifierGraph should have exactly one " + "classification head." + ) + languages_and_scores = classification_result.classifications[0] + language_detector_result = LanguageDetectorResult([]) + for category in languages_and_scores.categories: + if category.category_name is None: + raise ValueError( + "LanguageDetector ClassificationResult has a missing language code.") + prediction = LanguageDetectorResult.LanguageDetectorPrediction( + category.category_name, category.score + ) + language_detector_result.languages_and_scores.append(prediction) + return language_detector_result + + +@dataclasses.dataclass +class LanguageDetectorOptions: + """Options for the language detector task. + + Attributes: + base_options: Base options for the language detector task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. 
+  """
+  base_options: _BaseOptions
+  display_names_locale: Optional[str] = None
+  max_results: Optional[int] = None
+  score_threshold: Optional[float] = None
+  category_allowlist: Optional[List[str]] = None
+  category_denylist: Optional[List[str]] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
+    """Generates a TextClassifierGraphOptions protobuf object."""
+    base_options_proto = self.base_options.to_pb2()
+    classifier_options_proto = _ClassifierOptionsProto(
+        score_threshold=self.score_threshold,
+        category_allowlist=self.category_allowlist,
+        category_denylist=self.category_denylist,
+        display_names_locale=self.display_names_locale,
+        max_results=self.max_results)
+
+    return _TextClassifierGraphOptionsProto(
+        base_options=base_options_proto,
+        classifier_options=classifier_options_proto)
+
+
+class LanguageDetector(base_text_task_api.BaseTextTaskApi):
+  """Class that predicts the language of an input text.
+
+  This API expects a TFLite model with TFLite Model Metadata that contains the
+  mandatory (described below) input tensors, output tensor, and the language
+  codes in an AssociatedFile.
+
+  Input tensors:
+    (kTfLiteString)
+    - 1 input tensor that is scalar or has shape [1] containing the input
+      string.
+  Output tensor:
+    (kTfLiteFloat32)
+    - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
+  """
+
+  @classmethod
+  def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
+    """Creates a `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.
+
+    Args:
+      model_path: Path to the model.
+
+    Returns:
+      `LanguageDetector` object that's created from the model file and the
+      default `LanguageDetectorOptions`.
+
+    Raises:
+      ValueError: If failed to create `LanguageDetector` object from the
+        provided file such as an invalid file path.
+      RuntimeError: If other types of error occurred.
+    """
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = LanguageDetectorOptions(base_options=base_options)
+    return cls.create_from_options(options)
+
+  @classmethod
+  def create_from_options(cls,
+                          options: LanguageDetectorOptions) -> 'LanguageDetector':
+    """Creates the `LanguageDetector` object from language detector options.
+
+    Args:
+      options: Options for the language detector task.
+
+    Returns:
+      `LanguageDetector` object that's created from `options`.
+
+    Raises:
+      ValueError: If failed to create `LanguageDetector` object from
+        `LanguageDetectorOptions` such as missing the model.
+      RuntimeError: If other types of error occurred.
+    """
+    task_info = _TaskInfo(
+        task_graph=_TASK_GRAPH_NAME,
+        input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
+        output_streams=[
+            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
+        ],
+        task_options=options)
+    return cls(task_info.generate_graph_config())
+
+  def detect(self, text: str) -> LanguageDetectorResult:
+    """Predicts the language of the input `text`.
+
+    Args:
+      text: The input text.
+
+    Returns:
+      A `LanguageDetectorResult` object that contains a list of languages and
+      scores.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid.
+      RuntimeError: If language detection failed to run.
+    """
+    output_packets = self._runner.process(
+        {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)})
+
+    classification_result_proto = classifications_pb2.ClassificationResult()
+    classification_result_proto.CopyFrom(
+        packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
+
+    classification_result = _ClassificationResult.create_from_pb2(
+        classification_result_proto
+    )
+    return _extract_language_detector_result(classification_result)
diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py
index ad93c798c..1d9f5cf1a 100644
--- a/mediapipe/tasks/python/vision/interactive_segmenter.py
+++ b/mediapipe/tasks/python/vision/interactive_segmenter.py
@@ -88,7 +88,7 @@ class InteractiveSegmenterOptions:
 
   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
-    """Generates an InteractiveSegmenterOptions protobuf object."""
+    """Generates an ImageSegmenterGraphOptions protobuf object."""
     base_options_proto = self.base_options.to_pb2()
     base_options_proto.use_stream_mode = False
     segmenter_options_proto = _SegmenterOptionsProto()
diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD
index 9813b6543..62251ed8b 100644
--- a/mediapipe/tasks/testdata/text/BUILD
+++ b/mediapipe/tasks/testdata/text/BUILD
@@ -31,6 +31,7 @@ mediapipe_files(srcs = [
     "bert_text_classifier.tflite",
     "mobilebert_embedding_with_metadata.tflite",
     "mobilebert_with_metadata.tflite",
+    "language_detector.tflite",
    "regex_one_embedding_with_metadata.tflite",
     "test_model_text_classifier_bool_output.tflite",
     "test_model_text_classifier_with_regex_tokenizer.tflite",
@@ -78,6 +79,11 @@ filegroup(
     ],
 )
 
+filegroup(
+    name = "language_detector_model",
+    srcs = ["language_detector.tflite"],
+)
+
 filegroup(
     name = "text_classifier_models",
     srcs = [
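
For reviewers, a minimal usage sketch of the API added in this patch. The model path and input string below are placeholders (any language detector .tflite with the expected metadata would do); the calls themselves follow the options, factory methods, and result fields defined in language_detector.py above.

    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.text import language_detector

    # Configure the detector; score_threshold filters out low-confidence languages.
    options = language_detector.LanguageDetectorOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='/path/to/language_detector.tflite'),
        score_threshold=0.3)

    # The task object is a context manager, so the underlying graph is closed
    # automatically when the block exits (same pattern as the tests above).
    with language_detector.LanguageDetector.create_from_options(options) as detector:
      result = detector.detect('分久必合合久必分')
      for prediction in result.languages_and_scores:
        print(prediction.language_code, prediction.probability)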