Added Language Detector Python API and fixed a typo in Interactive Segmenter Options' docstring

This commit is contained in:
kinaryml 2023-04-21 11:46:21 -07:00
parent a6c35e9ba5
commit 2a2a55d1b8
6 changed files with 472 additions and 1 deletion
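A minimal usage sketch of the new API, based on the classes added in this commit (the model path below is a placeholder; point it at the language_detector.tflite asset that the tests use):

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.text import language_detector

# Placeholder path to a downloaded language detector TFLite model.
_MODEL_PATH = 'language_detector.tflite'

options = language_detector.LanguageDetectorOptions(
    base_options=base_options_module.BaseOptions(model_asset_path=_MODEL_PATH),
    score_threshold=0.3)

# The detector is a context manager; it is closed automatically on exit.
with language_detector.LanguageDetector.create_from_options(options) as detector:
  result = detector.detect('To be, or not to be, that is the question')
  for prediction in result.languages_and_scores:
    print(prediction.language_code, prediction.probability)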


@@ -49,3 +49,18 @@ py_test(
        "//mediapipe/tasks/python/text:text_embedder",
    ],
)

py_test(
    name = "language_detector_test",
    srcs = ["language_detector_test.py"],
    data = [
        "//mediapipe/tasks/testdata/text:language_detector_model",
    ],
    deps = [
        "//mediapipe/tasks/python/components/containers:category",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
        "//mediapipe/tasks/python/text:language_detector",
    ],
)


@@ -0,0 +1,225 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for language detector."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import language_detector
LanguageDetectorResult = language_detector.LanguageDetectorResult
LanguageDetectorPrediction = language_detector.LanguageDetectorResult.LanguageDetectorPrediction
_BaseOptions = base_options_module.BaseOptions
_Category = category.Category
_Classifications = classification_result_module.Classifications
_LanguageDetector = language_detector.LanguageDetector
_LanguageDetectorOptions = language_detector.LanguageDetectorOptions
_LANGUAGE_DETECTOR_MODEL = 'language_detector.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
_SCORE_THRESHOLD = 0.3
_EN_TEXT = "To be, or not to be, that is the question"
_EN_EXPECTED_RESULT = LanguageDetectorResult(
[
LanguageDetectorPrediction("en", 0.999856)
]
)
_FR_TEXT = (
"Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent."
)
_FR_EXPECTED_RESULT = LanguageDetectorResult(
[
LanguageDetectorPrediction("fr", 0.999781)
]
)
_RU_TEXT = "это какой-то английский язык"
_RU_EXPECTED_RESULT = LanguageDetectorResult(
[
LanguageDetectorPrediction("ru", 0.993362)
]
)
_MIXED_TEXT = "分久必合合久必分"
_MIXED_EXPECTED_RESULT = LanguageDetectorResult(
[
LanguageDetectorPrediction("zh", 0.505424),
LanguageDetectorPrediction("ja", 0.481617)
]
)
_TOLERANCE = 1e-6
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2


class LanguageDetectorTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL))

  def _expect_language_detector_result_correct(
      self,
      actual_result: LanguageDetectorResult,
      expect_result: LanguageDetectorResult
  ):
    for i, prediction in enumerate(actual_result.languages_and_scores):
      expected_prediction = expect_result.languages_and_scores[i]
      self.assertEqual(
          prediction.language_code, expected_prediction.language_code,
      )
      self.assertAlmostEqual(
          prediction.probability, expected_prediction.probability,
          delta=_TOLERANCE
      )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _LanguageDetector.create_from_model_path(self.model_path) as detector:
      self.assertIsInstance(detector, _LanguageDetector)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(base_options=base_options)
    with _LanguageDetector.create_from_options(options) as detector:
      self.assertIsInstance(detector, _LanguageDetector)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite')
      options = _LanguageDetectorOptions(base_options=base_options)
      _LanguageDetector.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _LanguageDetectorOptions(base_options=base_options)
      detector = _LanguageDetector.create_from_options(options)
      self.assertIsInstance(detector, _LanguageDetector)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT)
  )
  def test_detect(self, model_file_type, text, expected_result):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD
    )
    detector = _LanguageDetector.create_from_options(options)

    # Performs language detection on the input.
    text_result = detector.detect(text)
    # Comparing results.
    self._expect_language_detector_result_correct(text_result, expected_result)
    # Closes the detector explicitly when the detector is not used in
    # a context.
    detector.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT)
  )
  def test_detect_in_context(self, model_file_type, text, expected_result):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(text)
      # Comparing results.
      self._expect_language_detector_result_correct(text_result, expected_result)

  def test_allowlist_option(self):
    # Creates detector.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD,
        category_allowlist=["ja"]
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(_MIXED_TEXT)
      # Comparing results.
      expected_result = LanguageDetectorResult(
          [
              LanguageDetectorPrediction("ja", 0.481617)
          ]
      )
      self._expect_language_detector_result_correct(text_result, expected_result)

  def test_denylist_option(self):
    # Creates detector.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD,
        category_denylist=["ja"]
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(_MIXED_TEXT)
      # Comparing results.
      expected_result = LanguageDetectorResult(
          [
              LanguageDetectorPrediction("zh", 0.505424)
          ]
      )
      self._expect_language_detector_result_correct(text_result, expected_result)


if __name__ == '__main__':
  absltest.main()


@@ -57,3 +57,23 @@ py_library(
        "//mediapipe/tasks/python/text/core:base_text_task_api",
    ],
)

py_library(
    name = "language_detector",
    srcs = [
        "language_detector.py",
    ],
    visibility = ["//mediapipe/tasks:users"],
    deps = [
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
        "//mediapipe/tasks/python/text/core:base_text_task_api",
    ],
)


@@ -0,0 +1,205 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe language detector task."""
import dataclasses
from typing import Optional, List
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api
_ClassificationResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'


@dataclasses.dataclass
class LanguageDetectorResult:
  """The predicted language(s) and probabilities for the input text."""

  @dataclasses.dataclass
  class LanguageDetectorPrediction:
    """A language code and its probability."""

    language_code: str
    probability: float

  languages_and_scores: List[LanguageDetectorPrediction]


def _extract_language_detector_result(
    classification_result: classification_result_module.ClassificationResult
) -> LanguageDetectorResult:
  """Converts a ClassificationResult into a LanguageDetectorResult."""
  if len(classification_result.classifications) != 1:
    raise ValueError(
        "The LanguageDetector TextClassifierGraph should have exactly one "
        "classification head."
    )
  languages_and_scores = classification_result.classifications[0]
  language_detector_result = LanguageDetectorResult([])
  for category in languages_and_scores.categories:
    if category.category_name is None:
      raise ValueError(
          "LanguageDetector ClassificationResult has a missing language code.")
    prediction = LanguageDetectorResult.LanguageDetectorPrediction(
        category.category_name, category.score
    )
    language_detector_result.languages_and_scores.append(prediction)
  return language_detector_result


@dataclasses.dataclass
class LanguageDetectorOptions:
  """Options for the language detector task.

  Attributes:
    base_options: Base options for the language detector task.
    display_names_locale: The locale to use for display names specified through
      the TFLite Model Metadata.
    max_results: The maximum number of top-scored classification results to
      return.
    score_threshold: Overrides the ones provided in the model metadata. Results
      below this value are rejected.
    category_allowlist: Allowlist of category names. If non-empty,
      classification results whose category name is not in this set will be
      filtered out. Duplicate or unknown category names are ignored. Mutually
      exclusive with `category_denylist`.
    category_denylist: Denylist of category names. If non-empty, classification
      results whose category name is in this set will be filtered out.
      Duplicate or unknown category names are ignored. Mutually exclusive with
      `category_allowlist`.
  """

  base_options: _BaseOptions
  display_names_locale: Optional[str] = None
  max_results: Optional[int] = None
  score_threshold: Optional[float] = None
  category_allowlist: Optional[List[str]] = None
  category_denylist: Optional[List[str]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
    """Generates a TextClassifierGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    classifier_options_proto = _ClassifierOptionsProto(
        score_threshold=self.score_threshold,
        category_allowlist=self.category_allowlist,
        category_denylist=self.category_denylist,
        display_names_locale=self.display_names_locale,
        max_results=self.max_results)

    return _TextClassifierGraphOptionsProto(
        base_options=base_options_proto,
        classifier_options=classifier_options_proto)


class LanguageDetector(base_text_task_api.BaseTextTaskApi):
  """Class that predicts the language of an input text.

  This API expects a TFLite model with TFLite Model Metadata that contains the
  mandatory (described below) input tensors, output tensor, and the language
  codes in an AssociatedFile.

  Input tensors:
    (kTfLiteString)
    - 1 input tensor that is scalar or has shape [1] containing the input
      string.

  Output tensor:
    (kTfLiteFloat32)
    - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
  """

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
    """Creates a `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.

    Args:
      model_path: Path to the model.

    Returns:
      `LanguageDetector` object that's created from the model file and the
      default `LanguageDetectorOptions`.

    Raises:
      ValueError: If failed to create `LanguageDetector` object from the
        provided file such as invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = LanguageDetectorOptions(base_options=base_options)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(cls,
                          options: LanguageDetectorOptions) -> 'LanguageDetector':
    """Creates the `LanguageDetector` object from language detector options.

    Args:
      options: Options for the language detector task.

    Returns:
      `LanguageDetector` object that's created from `options`.

    Raises:
      ValueError: If failed to create `LanguageDetector` object from
        `LanguageDetectorOptions` such as missing the model.
      RuntimeError: If other types of error occurred.
    """
    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
        output_streams=[
            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
        ],
        task_options=options)
    return cls(task_info.generate_graph_config())

  def detect(self, text: str) -> LanguageDetectorResult:
    """Predicts the language of the input `text`.

    Args:
      text: The input text.

    Returns:
      A `LanguageDetectorResult` object that contains a list of languages and
      scores.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If language detection failed to run.
    """
    output_packets = self._runner.process(
        {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)})

    classification_result_proto = classifications_pb2.ClassificationResult()
    classification_result_proto.CopyFrom(
        packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))

    classification_result = _ClassificationResult.create_from_pb2(
        classification_result_proto
    )
    return _extract_language_detector_result(classification_result)


@@ -88,7 +88,7 @@ class InteractiveSegmenterOptions:
  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
-   """Generates an InteractiveSegmenterOptions protobuf object."""
+   """Generates an ImageSegmenterGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = False
    segmenter_options_proto = _SegmenterOptionsProto()


@@ -31,6 +31,7 @@ mediapipe_files(srcs = [
    "bert_text_classifier.tflite",
    "mobilebert_embedding_with_metadata.tflite",
    "mobilebert_with_metadata.tflite",
    "language_detector.tflite",
    "regex_one_embedding_with_metadata.tflite",
    "test_model_text_classifier_bool_output.tflite",
    "test_model_text_classifier_with_regex_tokenizer.tflite",
@@ -78,6 +79,11 @@ filegroup(
    ],
)

filegroup(
    name = "language_detector_model",
    srcs = ["language_detector.tflite"],
)

filegroup(
    name = "text_classifier_models",
    srcs = [