Added text classifier implementation and tests

This commit is contained in:
kinaryml 2022-11-03 02:29:40 -07:00
parent ddf37d014e
commit 6e9a070dd1
10 changed files with 615 additions and 0 deletions

View File

@ -90,6 +90,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
], ],
) )

View File

@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_test(
name = "text_classifier_test",
srcs = ["text_classifier_test.py"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:text_classifier",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,246 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for text classifier."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_classifier
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_ClassificationResult = classifications_module.ClassificationResult
_TextClassifier = text_classifier.TextClassifier
_TextClassifierOptions = text_classifier.TextClassifierOptions
_BERT_MODEL_FILE = 'bert_text_classifier.tflite'
_REGEX_MODEL_FILE = 'test_model_text_classifier_with_regex_tokenizer.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
_NEGATIVE_TEXT = "What a waste of my time."
_POSITIVE_TEXT = ("This is the best movie Ive seen in recent years."
"Strongly recommend it!")
_BERT_NEGATIVE_RESULTS = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=0, score=0.999479, display_name='',
category_name='negative'),
_Category(
index=1, score=0.00052154, display_name='',
category_name='positive')
],
timestamp_ms=0
)
],
head_index=0,
head_name='probability')
])
_BERT_POSITIVE_RESULTS = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=1, score=0.999466, display_name='',
category_name='positive'),
_Category(
index=0, score=0.000533596, display_name='',
category_name='negative')
],
timestamp_ms=0
)
],
head_index=0,
head_name='probability')
])
_REGEX_NEGATIVE_RESULTS = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=0, score=0.81313, display_name='',
category_name='Negative'),
_Category(
index=1, score=0.1868704, display_name='',
category_name='Positive')
],
timestamp_ms=0
)
],
head_index=0,
head_name='probability')
])
_REGEX_POSITIVE_RESULTS = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=1, score=0.5134273, display_name='',
category_name='Positive'),
_Category(
index=0, score=0.486573, display_name='',
category_name='Negative')
],
timestamp_ms=0
)
],
head_index=0,
head_name='probability')
])
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class ImageClassifierTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE))
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _TextClassifier.create_from_model_path(self.model_path) as classifier:
self.assertIsInstance(classifier, _TextClassifier)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _TextClassifierOptions(base_options=base_options)
with _TextClassifier.create_from_options(options) as classifier:
self.assertIsInstance(classifier, _TextClassifier)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
ValueError,
r"ExternalFile must specify at least one of 'file_content', "
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='')
options = _TextClassifierOptions(base_options=base_options)
_TextClassifier.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _TextClassifierOptions(base_options=base_options)
classifier = _TextClassifier.create_from_options(options)
self.assertIsInstance(classifier, _TextClassifier)
@parameterized.parameters(
(ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
_BERT_NEGATIVE_RESULTS),
(ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
_BERT_NEGATIVE_RESULTS),
(ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _POSITIVE_TEXT,
_BERT_POSITIVE_RESULTS),
(ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _POSITIVE_TEXT,
_BERT_POSITIVE_RESULTS),
(ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _NEGATIVE_TEXT,
_REGEX_NEGATIVE_RESULTS),
(ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, _NEGATIVE_TEXT,
_REGEX_NEGATIVE_RESULTS),
(ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _POSITIVE_TEXT,
_REGEX_POSITIVE_RESULTS),
(ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, _POSITIVE_TEXT,
_REGEX_POSITIVE_RESULTS))
def test_classify(self, model_file_type, model_name, text,
expected_classification_result):
# Creates classifier.
model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, model_name))
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions()
options = _TextClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
classifier = _TextClassifier.create_from_options(options)
# Performs text classification on the input.
text_result = classifier.classify(text)
# Comparing results.
test_utils.assert_proto_equals(self, text_result.to_pb2(),
expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in
# a context.
classifier.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
_BERT_NEGATIVE_RESULTS),
(ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
_BERT_NEGATIVE_RESULTS))
def test_classify_in_context(self, model_file_type, model_name, text,
expected_classification_result):
# Creates classifier.
model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, model_name))
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions()
options = _TextClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
with _TextClassifier.create_from_options(options) as classifier:
# Performs text classification on the input.
text_result = classifier.classify(text)
# Comparing results.
test_utils.assert_proto_equals(self, text_result.to_pb2(),
expected_classification_result.to_pb2())
if __name__ == '__main__':
absltest.main()

View File

@ -0,0 +1,39 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "text_classifier",
srcs = [
"text_classifier.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/text/core:base_text_task_api",
],
)

View File

@ -0,0 +1,31 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Tasks Vision API."""
import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.image_classifier
import mediapipe.tasks.python.vision.object_detector
ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions
ObjectDetector = object_detector.ObjectDetector
ObjectDetectorOptions = object_detector.ObjectDetectorOptions
RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
# Remove unnecessary modules to avoid duplication in API docs.
del core
del image_classifier
del object_detector
del mediapipe

View File

@ -0,0 +1,31 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "base_text_task_api",
srcs = [
"base_text_task_api.py",
],
deps = [
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -0,0 +1,14 @@
"""Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

View File

@ -0,0 +1,57 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe text task base api."""
from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings import task_runner
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_TaskRunner = task_runner.TaskRunner
class BaseTextTaskApi(object):
"""The base class of the user-facing mediapipe text task api classes."""
def __init__(
self,
graph_config: calculator_pb2.CalculatorGraphConfig
) -> None:
"""Initializes the `BaseVisionTaskApi` object.
Args:
graph_config: The mediapipe text task graph config proto.
"""
self._runner = _TaskRunner.create(graph_config)
def close(self) -> None:
"""Shuts down the mediapipe text task instance.
Raises:
RuntimeError: If the mediapipe text task failed to close.
"""
self._runner.close()
@doc_controls.do_not_generate_docs
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@doc_controls.do_not_generate_docs
def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
"""Shuts down the mediapipe text task instance on exit of the context manager.
Raises:
RuntimeError: If the mediapipe text task failed to close.
"""
self.close()

View File

@ -0,0 +1,146 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe text classifier task."""
import dataclasses
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
# TODO: Import MPImage directly one we have an alias
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class TextClassifierOptions:
"""Options for the text classifier task.
Attributes:
base_options: Base options for the text classifier task.
classifier_options: Options for the text classification task.
"""
base_options: _BaseOptions
classifier_options: _ClassifierOptions = _ClassifierOptions()
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _TextClassifierGraphOptionsProto:
"""Generates an TextClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
classifier_options_proto = self.classifier_options.to_pb2()
return _TextClassifierGraphOptionsProto(
base_options=base_options_proto,
classifier_options=classifier_options_proto)
class TextClassifier(base_text_task_api.BaseTextTaskApi):
"""Class that performs classification on text."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'TextClassifier':
"""Creates an `TextClassifier` object from a TensorFlow Lite model and the default `TextClassifierOptions`.
Args:
model_path: Path to the model.
Returns:
`TextClassifier` object that's created from the model file and the
default `TextClassifierOptions`.
Raises:
ValueError: If failed to create `TextClassifier` object from the provided
file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = TextClassifierOptions(base_options=base_options)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: TextClassifierOptions) -> 'TextClassifier':
"""Creates the `TextClassifier` object from text classifier options.
Args:
options: Options for the text classifier task.
Returns:
`TextClassifier` object that's created from `options`.
Raises:
ValueError: If failed to create `TextClassifier` object from
`TextClassifierOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])
],
output_streams=[
':'.join([
_CLASSIFICATION_RESULT_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
])
],
task_options=options)
return cls(task_info.generate_graph_config())
def classify(
self,
text: str,
) -> classifications.ClassificationResult:
"""Performs classification on the input `text`.
Args:
text: The input text.
Returns:
A classification result object that contains a list of classifications.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If text classification failed to run.
"""
output_packets = self._runner.process({
_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)
})
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])