Rename TextClassificationResult to TextClassifierResult.

PiperOrigin-RevId: 486685936
This commit is contained in:
Jiuqiang Tang 2022-11-07 09:37:03 -08:00 committed by Copybara-Service
parent 4c06303ec7
commit 17a2de8cf7
2 changed files with 9 additions and 9 deletions

View File

@ -31,7 +31,7 @@ _ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category _Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry _ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications _Classifications = classifications_module.Classifications
_TextClassificationResult = classifications_module.ClassificationResult _TextClassifierResult = classifications_module.ClassificationResult
_TextClassifier = text_classifier.TextClassifier _TextClassifier = text_classifier.TextClassifier
_TextClassifierOptions = text_classifier.TextClassifierOptions _TextClassifierOptions = text_classifier.TextClassifierOptions
@ -43,7 +43,7 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
_POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.' _POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.'
'Strongly recommend it!') 'Strongly recommend it!')
_BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[ _BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications( _Classifications(
entries=[ entries=[
_ClassificationEntry( _ClassificationEntry(
@ -64,7 +64,7 @@ _BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ])
_BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[ _BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications( _Classifications(
entries=[ entries=[
_ClassificationEntry( _ClassificationEntry(
@ -85,7 +85,7 @@ _BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ])
_REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[ _REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications( _Classifications(
entries=[ entries=[
_ClassificationEntry( _ClassificationEntry(
@ -106,7 +106,7 @@ _REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ])
_REGEX_POSITIVE_RESULTS = _TextClassificationResult(classifications=[ _REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications( _Classifications(
entries=[ entries=[
_ClassificationEntry( _ClassificationEntry(

View File

@ -26,7 +26,7 @@ from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api from mediapipe.tasks.python.text.core import base_text_task_api
TextClassificationResult = classifications.ClassificationResult TextClassifierResult = classifications.ClassificationResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
@ -112,14 +112,14 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
task_options=options) task_options=options)
return cls(task_info.generate_graph_config()) return cls(task_info.generate_graph_config())
def classify(self, text: str) -> TextClassificationResult: def classify(self, text: str) -> TextClassifierResult:
"""Performs classification on the input `text`. """Performs classification on the input `text`.
Args: Args:
text: The input text. text: The input text.
Returns: Returns:
A `TextClassificationResult` object that contains a list of text A `TextClassifierResult` object that contains a list of text
classifications. classifications.
Raises: Raises:
@ -134,7 +134,7 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
packet_getter.get_proto( packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return TextClassificationResult([ return TextClassifierResult([
classifications.Classifications.create_from_pb2(classification) classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications for classification in classification_result_proto.classifications
]) ])