Rename TextClassificationResult to TextClassifierResult.

PiperOrigin-RevId: 486685936
This commit is contained in:
Jiuqiang Tang 2022-11-07 09:37:03 -08:00 committed by Copybara-Service
parent 4c06303ec7
commit 17a2de8cf7
2 changed files with 9 additions and 9 deletions

View File

@ -31,7 +31,7 @@ _ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_TextClassificationResult = classifications_module.ClassificationResult
_TextClassifierResult = classifications_module.ClassificationResult
_TextClassifier = text_classifier.TextClassifier
_TextClassifierOptions = text_classifier.TextClassifierOptions
@ -43,7 +43,7 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
_POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.'
'Strongly recommend it!')
_BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
@ -64,7 +64,7 @@ _BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
head_index=0,
head_name='probability')
])
_BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
@ -85,7 +85,7 @@ _BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
head_index=0,
head_name='probability')
])
_REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
@ -106,7 +106,7 @@ _REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
head_index=0,
head_name='probability')
])
_REGEX_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(

View File

@ -26,7 +26,7 @@ from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api
TextClassificationResult = classifications.ClassificationResult
TextClassifierResult = classifications.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
@ -112,14 +112,14 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
task_options=options)
return cls(task_info.generate_graph_config())
def classify(self, text: str) -> TextClassificationResult:
def classify(self, text: str) -> TextClassifierResult:
"""Performs classification on the input `text`.
Args:
text: The input text.
Returns:
A `TextClassificationResult` object that contains a list of text
A `TextClassifierResult` object that contains a list of text
classifications.
Raises:
@ -134,7 +134,7 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return TextClassificationResult([
return TextClassifierResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])