Rename TextClassificationResult to TextClassifierResult.
PiperOrigin-RevId: 486685936
This commit is contained in:
parent
4c06303ec7
commit
17a2de8cf7
|
@ -31,7 +31,7 @@ _ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_Category = category.Category
|
_Category = category.Category
|
||||||
_ClassificationEntry = classifications_module.ClassificationEntry
|
_ClassificationEntry = classifications_module.ClassificationEntry
|
||||||
_Classifications = classifications_module.Classifications
|
_Classifications = classifications_module.Classifications
|
||||||
_TextClassificationResult = classifications_module.ClassificationResult
|
_TextClassifierResult = classifications_module.ClassificationResult
|
||||||
_TextClassifier = text_classifier.TextClassifier
|
_TextClassifier = text_classifier.TextClassifier
|
||||||
_TextClassifierOptions = text_classifier.TextClassifierOptions
|
_TextClassifierOptions = text_classifier.TextClassifierOptions
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
|
||||||
_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
|
_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
|
||||||
'Strongly recommend it!')
|
'Strongly recommend it!')
|
||||||
|
|
||||||
_BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
|
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
entries=[
|
||||||
_ClassificationEntry(
|
_ClassificationEntry(
|
||||||
|
@ -64,7 +64,7 @@ _BERT_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
])
|
||||||
_BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
|
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
entries=[
|
||||||
_ClassificationEntry(
|
_ClassificationEntry(
|
||||||
|
@ -85,7 +85,7 @@ _BERT_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
])
|
||||||
_REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
|
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
entries=[
|
||||||
_ClassificationEntry(
|
_ClassificationEntry(
|
||||||
|
@ -106,7 +106,7 @@ _REGEX_NEGATIVE_RESULTS = _TextClassificationResult(classifications=[
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
])
|
||||||
_REGEX_POSITIVE_RESULTS = _TextClassificationResult(classifications=[
|
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
entries=[
|
||||||
_ClassificationEntry(
|
_ClassificationEntry(
|
||||||
|
|
|
@ -26,7 +26,7 @@ from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
from mediapipe.tasks.python.text.core import base_text_task_api
|
from mediapipe.tasks.python.text.core import base_text_task_api
|
||||||
|
|
||||||
TextClassificationResult = classifications.ClassificationResult
|
TextClassifierResult = classifications.ClassificationResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
||||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
|
@ -112,14 +112,14 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
||||||
task_options=options)
|
task_options=options)
|
||||||
return cls(task_info.generate_graph_config())
|
return cls(task_info.generate_graph_config())
|
||||||
|
|
||||||
def classify(self, text: str) -> TextClassificationResult:
|
def classify(self, text: str) -> TextClassifierResult:
|
||||||
"""Performs classification on the input `text`.
|
"""Performs classification on the input `text`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The input text.
|
text: The input text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `TextClassificationResult` object that contains a list of text
|
A `TextClassifierResult` object that contains a list of text
|
||||||
classifications.
|
classifications.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -134,7 +134,7 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
||||||
packet_getter.get_proto(
|
packet_getter.get_proto(
|
||||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
return TextClassificationResult([
|
return TextClassifierResult([
|
||||||
classifications.Classifications.create_from_pb2(classification)
|
classifications.Classifications.create_from_pb2(classification)
|
||||||
for classification in classification_result_proto.classifications
|
for classification in classification_result_proto.classifications
|
||||||
])
|
])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user