Added Language Detector Python API and fixed a typo in Interactive Segmenter Options' docstring
This commit is contained in:
parent
a6c35e9ba5
commit
2a2a55d1b8
|
@ -49,3 +49,18 @@ py_test(
|
||||||
"//mediapipe/tasks/python/text:text_embedder",
|
"//mediapipe/tasks/python/text:text_embedder",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "language_detector_test",
|
||||||
|
srcs = ["language_detector_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:language_detector_model",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
"//mediapipe/tasks/python/text:language_detector",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
225
mediapipe/tasks/python/test/text/language_detector_test.py
Normal file
225
mediapipe/tasks/python/test/text/language_detector_test.py
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for language detector."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.components.containers import category
|
||||||
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
from mediapipe.tasks.python.text import language_detector
|
||||||
|
|
||||||
|
LanguageDetectorResult = language_detector.LanguageDetectorResult
|
||||||
|
LanguageDetectorPrediction = language_detector.LanguageDetectorResult.LanguageDetectorPrediction
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_Category = category.Category
|
||||||
|
_Classifications = classification_result_module.Classifications
|
||||||
|
_LanguageDetector = language_detector.LanguageDetector
|
||||||
|
_LanguageDetectorOptions = language_detector.LanguageDetectorOptions
|
||||||
|
|
||||||
|
_LANGUAGE_DETECTOR_MODEL = 'language_detector.tflite'
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
||||||
|
|
||||||
|
_SCORE_THRESHOLD = 0.3
|
||||||
|
_EN_TEXT = "To be, or not to be, that is the question"
|
||||||
|
_EN_EXPECTED_RESULT = LanguageDetectorResult(
|
||||||
|
[
|
||||||
|
LanguageDetectorPrediction("en", 0.999856)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_FR_TEXT = (
|
||||||
|
"Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent."
|
||||||
|
)
|
||||||
|
_FR_EXPECTED_RESULT = LanguageDetectorResult(
|
||||||
|
[
|
||||||
|
LanguageDetectorPrediction("fr", 0.999781)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_RU_TEXT = "это какой-то английский язык"
|
||||||
|
_RU_EXPECTED_RESULT = LanguageDetectorResult(
|
||||||
|
[
|
||||||
|
LanguageDetectorPrediction("ru", 0.993362)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_MIXED_TEXT = "分久必合合久必分"
|
||||||
|
_MIXED_EXPECTED_RESULT = LanguageDetectorResult(
|
||||||
|
[
|
||||||
|
LanguageDetectorPrediction("zh", 0.505424),
|
||||||
|
LanguageDetectorPrediction("ja", 0.481617)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_TOLERANCE = 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFileType(enum.Enum):
|
||||||
|
FILE_CONTENT = 1
|
||||||
|
FILE_NAME = 2
|
||||||
|
|
||||||
|
|
||||||
|
class LanguageDetectorTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL))
|
||||||
|
|
||||||
|
def _expect_language_detector_result_correct(
|
||||||
|
self,
|
||||||
|
actual_result: LanguageDetectorResult,
|
||||||
|
expect_result: LanguageDetectorResult
|
||||||
|
):
|
||||||
|
for i, prediction in enumerate(actual_result.languages_and_scores):
|
||||||
|
expected_prediction = expect_result.languages_and_scores[i]
|
||||||
|
self.assertEqual(
|
||||||
|
prediction.language_code, expected_prediction.language_code,
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
prediction.probability, expected_prediction.probability,
|
||||||
|
delta=_TOLERANCE
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with default option and valid model file successfully.
|
||||||
|
with _LanguageDetector.create_from_model_path(self.model_path) as detector:
|
||||||
|
self.assertIsInstance(detector, _LanguageDetector)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with options containing model file successfully.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _LanguageDetectorOptions(base_options=base_options)
|
||||||
|
with _LanguageDetector.create_from_options(options) as detector:
|
||||||
|
self.assertIsInstance(detector, _LanguageDetector)
|
||||||
|
|
||||||
|
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||||
|
base_options = _BaseOptions(
|
||||||
|
model_asset_path='/path/to/invalid/model.tflite')
|
||||||
|
options = _LanguageDetectorOptions(base_options=base_options)
|
||||||
|
_LanguageDetector.create_from_options(options)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||||
|
# Creates with options containing model content successfully.
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||||
|
options = _LanguageDetectorOptions(base_options=base_options)
|
||||||
|
detector = _LanguageDetector.create_from_options(options)
|
||||||
|
self.assertIsInstance(detector, _LanguageDetector)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT)
|
||||||
|
)
|
||||||
|
def test_detect(self, model_file_type, text, expected_result):
|
||||||
|
# Creates detector.
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _LanguageDetectorOptions(
|
||||||
|
base_options=base_options, score_threshold=_SCORE_THRESHOLD
|
||||||
|
)
|
||||||
|
detector = _LanguageDetector.create_from_options(options)
|
||||||
|
|
||||||
|
# Performs language detection on the input.
|
||||||
|
text_result = detector.detect(text)
|
||||||
|
# Comparing results.
|
||||||
|
self._expect_language_detector_result_correct(text_result, expected_result)
|
||||||
|
# Closes the detector explicitly when the detector is not used in
|
||||||
|
# a context.
|
||||||
|
detector.close()
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
|
||||||
|
(ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT)
|
||||||
|
)
|
||||||
|
def test_detect_in_context(self, model_file_type, text, expected_result):
|
||||||
|
# Creates detector.
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _LanguageDetectorOptions(
|
||||||
|
base_options=base_options, score_threshold=_SCORE_THRESHOLD
|
||||||
|
)
|
||||||
|
with _LanguageDetector.create_from_options(options) as detector:
|
||||||
|
# Performs language detection on the input.
|
||||||
|
text_result = detector.detect(text)
|
||||||
|
# Comparing results.
|
||||||
|
self._expect_language_detector_result_correct(text_result, expected_result)
|
||||||
|
|
||||||
|
def test_allowlist_option(self):
|
||||||
|
# Creates detector.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _LanguageDetectorOptions(
|
||||||
|
base_options=base_options, score_threshold=_SCORE_THRESHOLD,
|
||||||
|
category_allowlist=["ja"]
|
||||||
|
)
|
||||||
|
with _LanguageDetector.create_from_options(options) as detector:
|
||||||
|
# Performs language detection on the input.
|
||||||
|
text_result = detector.detect(_MIXED_TEXT)
|
||||||
|
# Comparing results.
|
||||||
|
expected_result = LanguageDetectorResult(
|
||||||
|
[
|
||||||
|
LanguageDetectorPrediction("ja", 0.481617)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._expect_language_detector_result_correct(text_result, expected_result)
|
||||||
|
|
||||||
|
def test_denylist_option(self):
|
||||||
|
# Creates detector.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _LanguageDetectorOptions(
|
||||||
|
base_options=base_options, score_threshold=_SCORE_THRESHOLD,
|
||||||
|
category_denylist=["ja"]
|
||||||
|
)
|
||||||
|
with _LanguageDetector.create_from_options(options) as detector:
|
||||||
|
# Performs language detection on the input.
|
||||||
|
text_result = detector.detect(_MIXED_TEXT)
|
||||||
|
# Comparing results.
|
||||||
|
expected_result = LanguageDetectorResult(
|
||||||
|
[
|
||||||
|
LanguageDetectorPrediction("zh", 0.505424)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._expect_language_detector_result_correct(text_result, expected_result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
@ -57,3 +57,23 @@ py_library(
|
||||||
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "language_detector",
|
||||||
|
srcs = [
|
||||||
|
"language_detector.py",
|
||||||
|
],
|
||||||
|
visibility = ["//mediapipe/tasks:users"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/python:packet_creator",
|
||||||
|
"//mediapipe/python:packet_getter",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
205
mediapipe/tasks/python/text/language_detector.py
Normal file
205
mediapipe/tasks/python/text/language_detector.py
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe language detector task."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from mediapipe.python import packet_creator
|
||||||
|
from mediapipe.python import packet_getter
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
|
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
|
||||||
|
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
|
||||||
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
from mediapipe.tasks.python.text.core import base_text_task_api
|
||||||
|
|
||||||
|
_ClassificationResult = classification_result_module.ClassificationResult
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
||||||
|
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
||||||
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
|
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
|
||||||
|
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
|
||||||
|
_TEXT_IN_STREAM_NAME = 'text_in'
|
||||||
|
_TEXT_TAG = 'TEXT'
|
||||||
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LanguageDetectorResult:
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LanguageDetectorPrediction:
|
||||||
|
"""A language code and its probability."""
|
||||||
|
language_code: str
|
||||||
|
probability: float
|
||||||
|
|
||||||
|
languages_and_scores: List[LanguageDetectorPrediction]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_language_detector_result(
|
||||||
|
classification_result: classification_result_module.ClassificationResult
|
||||||
|
) -> LanguageDetectorResult:
|
||||||
|
if len(classification_result.classifications) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"The LanguageDetector TextClassifierGraph should have exactly one "
|
||||||
|
"classification head."
|
||||||
|
)
|
||||||
|
languages_and_scores = classification_result.classifications[0]
|
||||||
|
language_detector_result = LanguageDetectorResult([])
|
||||||
|
for category in languages_and_scores.categories:
|
||||||
|
if category.category_name is None:
|
||||||
|
raise ValueError(
|
||||||
|
"LanguageDetector ClassificationResult has a missing language code.")
|
||||||
|
prediction = LanguageDetectorResult.LanguageDetectorPrediction(
|
||||||
|
category.category_name, category.score
|
||||||
|
)
|
||||||
|
language_detector_result.languages_and_scores.append(prediction)
|
||||||
|
return language_detector_result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LanguageDetectorOptions:
|
||||||
|
"""Options for the language detector task.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
base_options: Base options for the language detector task.
|
||||||
|
display_names_locale: The locale to use for display names specified through
|
||||||
|
the TFLite Model Metadata.
|
||||||
|
max_results: The maximum number of top-scored classification results to
|
||||||
|
return.
|
||||||
|
score_threshold: Overrides the ones provided in the model metadata. Results
|
||||||
|
below this value are rejected.
|
||||||
|
category_allowlist: Allowlist of category names. If non-empty,
|
||||||
|
classification results whose category name is not in this set will be
|
||||||
|
filtered out. Duplicate or unknown category names are ignored. Mutually
|
||||||
|
exclusive with `category_denylist`.
|
||||||
|
category_denylist: Denylist of category names. If non-empty, classification
|
||||||
|
results whose category name is in this set will be filtered out. Duplicate
|
||||||
|
or unknown category names are ignored. Mutually exclusive with
|
||||||
|
`category_allowlist`.
|
||||||
|
"""
|
||||||
|
base_options: _BaseOptions
|
||||||
|
display_names_locale: Optional[str] = None
|
||||||
|
max_results: Optional[int] = None
|
||||||
|
score_threshold: Optional[float] = None
|
||||||
|
category_allowlist: Optional[List[str]] = None
|
||||||
|
category_denylist: Optional[List[str]] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _TextClassifierGraphOptionsProto:
|
||||||
|
"""Generates an TextClassifierOptions protobuf object."""
|
||||||
|
base_options_proto = self.base_options.to_pb2()
|
||||||
|
classifier_options_proto = _ClassifierOptionsProto(
|
||||||
|
score_threshold=self.score_threshold,
|
||||||
|
category_allowlist=self.category_allowlist,
|
||||||
|
category_denylist=self.category_denylist,
|
||||||
|
display_names_locale=self.display_names_locale,
|
||||||
|
max_results=self.max_results)
|
||||||
|
|
||||||
|
return _TextClassifierGraphOptionsProto(
|
||||||
|
base_options=base_options_proto,
|
||||||
|
classifier_options=classifier_options_proto)
|
||||||
|
|
||||||
|
|
||||||
|
class LanguageDetector(base_text_task_api.BaseTextTaskApi):
|
||||||
|
"""Class that predicts the language of an input text.
|
||||||
|
|
||||||
|
This API expects a TFLite model with TFLite Model Metadata that contains the
|
||||||
|
mandatory (described below) input tensors, output tensor, and the language
|
||||||
|
codes in an AssociatedFile.
|
||||||
|
|
||||||
|
Input tensors:
|
||||||
|
(kTfLiteString)
|
||||||
|
- 1 input tensor that is scalar or has shape [1] containing the input
|
||||||
|
string.
|
||||||
|
Output tensor:
|
||||||
|
(kTfLiteFloat32)
|
||||||
|
- 1 output tensor of shape`[1 x N]` where `N` is the number of languages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
|
||||||
|
"""Creates an `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`LanguageDetector` object that's created from the model file and the
|
||||||
|
default `LanguageDetectorOptions`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `LanguageDetector` object from the provided
|
||||||
|
file such as invalid file path.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
options = LanguageDetectorOptions(base_options=base_options)
|
||||||
|
return cls.create_from_options(options)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_options(cls,
|
||||||
|
options: LanguageDetectorOptions) -> 'LanguageDetector':
|
||||||
|
"""Creates the `LanguageDetector` object from language detector options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: Options for the language detector task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`LanguageDetector` object that's created from `options`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `LanguageDetector` object from
|
||||||
|
`LanguageDetectorOptions` such as missing the model.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
task_info = _TaskInfo(
|
||||||
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
|
||||||
|
output_streams=[
|
||||||
|
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
|
||||||
|
],
|
||||||
|
task_options=options)
|
||||||
|
return cls(task_info.generate_graph_config())
|
||||||
|
|
||||||
|
def detect(self, text: str) -> LanguageDetectorResult:
|
||||||
|
"""Predicts the language of the input `text`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The input text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `LanguageDetectorResult` object that contains a list of languages and
|
||||||
|
scores.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If language detection failed to run.
|
||||||
|
"""
|
||||||
|
output_packets = self._runner.process(
|
||||||
|
{_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)})
|
||||||
|
|
||||||
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
|
classification_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||||
|
|
||||||
|
classification_result = _ClassificationResult.create_from_pb2(
|
||||||
|
classification_result_proto
|
||||||
|
)
|
||||||
|
return _extract_language_detector_result(classification_result)
|
|
@ -88,7 +88,7 @@ class InteractiveSegmenterOptions:
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
||||||
"""Generates an InteractiveSegmenterOptions protobuf object."""
|
"""Generates an ImageSegmenterGraphOptions protobuf object."""
|
||||||
base_options_proto = self.base_options.to_pb2()
|
base_options_proto = self.base_options.to_pb2()
|
||||||
base_options_proto.use_stream_mode = False
|
base_options_proto.use_stream_mode = False
|
||||||
segmenter_options_proto = _SegmenterOptionsProto()
|
segmenter_options_proto = _SegmenterOptionsProto()
|
||||||
|
|
6
mediapipe/tasks/testdata/text/BUILD
vendored
6
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -31,6 +31,7 @@ mediapipe_files(srcs = [
|
||||||
"bert_text_classifier.tflite",
|
"bert_text_classifier.tflite",
|
||||||
"mobilebert_embedding_with_metadata.tflite",
|
"mobilebert_embedding_with_metadata.tflite",
|
||||||
"mobilebert_with_metadata.tflite",
|
"mobilebert_with_metadata.tflite",
|
||||||
|
"language_detector.tflite",
|
||||||
"regex_one_embedding_with_metadata.tflite",
|
"regex_one_embedding_with_metadata.tflite",
|
||||||
"test_model_text_classifier_bool_output.tflite",
|
"test_model_text_classifier_bool_output.tflite",
|
||||||
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
||||||
|
@ -78,6 +79,11 @@ filegroup(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "language_detector_model",
|
||||||
|
srcs = ["language_detector.tflite"],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "text_classifier_models",
|
name = "text_classifier_models",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
Loading…
Reference in New Issue
Block a user