diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index 2d4c07b28..c93def48e 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -31,7 +31,7 @@ _ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category _ClassificationEntry = classifications_module.ClassificationEntry _Classifications = classifications_module.Classifications -_TextClassificationResult = classifications_module.ClassificationResult +_TextClassifierResult = classifications_module.ClassificationResult _TextClassifier = text_classifier.TextClassifier _TextClassifierOptions = text_classifier.TextClassifierOptions @@ -43,7 +43,7 @@ _NEGATIVE_TEXT = 'What a waste of my time.' _POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.' 'Strongly recommend it!') -_BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[ +_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ _Classifications( entries=[ _ClassificationEntry( @@ -64,7 +64,7 @@ _BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[ head_index=0, head_name='probability') ]) -_BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[ +_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ _Classifications( entries=[ _ClassificationEntry( @@ -85,7 +85,7 @@ _BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[ head_index=0, head_name='probability') ]) -_REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[ +_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ _Classifications( entries=[ _ClassificationEntry( @@ -106,7 +106,7 @@ _REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[ head_index=0, head_name='probability') ]) -_REGEX_POSITIVE_RESULTS = _TextClassificationResult(classifications=[ +_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ _Classifications( entries=[ _ClassificationEntry( diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 2ae5dee42..1e230ee20 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -26,7 +26,7 @@ from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.text.core import base_text_task_api -TextClassificationResult = classifications.ClassificationResult +TextClassifierResult = classifications.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions @@ -112,14 +112,14 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi): task_options=options) return cls(task_info.generate_graph_config()) - def classify(self, text: str) -> TextClassificationResult: + def classify(self, text: str) -> TextClassifierResult: """Performs classification on the input `text`. Args: text: The input text. Returns: - A `TextClassificationResult` object that contains a list of text + A `TextClassifierResult` object that contains a list of text classifications. Raises: @@ -134,7 +134,7 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi): packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) - return TextClassificationResult([ + return TextClassifierResult([ classifications.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ])