Metadata Writer: Add metadata writer for Text Classifier.
PiperOrigin-RevId: 486844428
parent b14178d305
commit 0a08e4768b
@@ -43,3 +43,9 @@ py_library(
     srcs = ["image_classifier.py"],
     deps = [":metadata_writer"],
 )
+
+py_library(
+    name = "text_classifier",
+    srcs = ["text_classifier.py"],
+    deps = [":metadata_writer"],
+)
@@ -62,10 +62,10 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
 
     Returns:
-      An MetadataWrite object.
+      A MetadataWriter object.
     """
     writer = metadata_writer.MetadataWriter(model_buffer)
-    writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION)
+    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
     writer.add_image_input(input_norm_mean, input_norm_std)
     writer.add_classification_output(labels, score_calibration)
     return cls(writer)
@@ -228,6 +228,45 @@ class ScoreThresholdingMd:
     return score_thresholding
 
 
+class RegexTokenizerMd:
+  """A container for the Regex tokenizer [1] metadata information.
+
+  [1]:
+  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
+  """
+
+  def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
+    """Initializes a RegexTokenizerMd object.
+
+    Args:
+      delim_regex_pattern: the regular expression to segment strings and create
+        tokens.
+      vocab_file_path: path to the vocabulary file.
+    """
+    self._delim_regex_pattern = delim_regex_pattern
+    self._vocab_file_path = vocab_file_path
+
+  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
+    """Creates the Regex tokenizer metadata based on the information.
+
+    Returns:
+      A Flatbuffers Python object of the Regex tokenizer metadata.
+    """
+    vocab = _metadata_fb.AssociatedFileT()
+    vocab.name = self._vocab_file_path
+    vocab.description = _VOCAB_FILE_DESCRIPTION
+    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
+
+    # Create the RegexTokenizer.
+    tokenizer = _metadata_fb.ProcessUnitT()
+    tokenizer.optionsType = (
+        _metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
+    tokenizer.options = _metadata_fb.RegexTokenizerOptionsT()
+    tokenizer.options.delimRegexPattern = self._delim_regex_pattern
+    tokenizer.options.vocabFile = [vocab]
+    return tokenizer
+
+
 class TensorMd:
   """A container for common tensor metadata information.
 
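A minimal sketch of how this new container is exercised (mirroring the metadata_info_test changes later in this commit; "vocab.txt" is a placeholder path):

    # Hypothetical usage sketch, not part of the commit itself.
    tokenizer_md = metadata_info.RegexTokenizerMd(
        delim_regex_pattern=r"[^\w\']+", vocab_file_path="vocab.txt")
    process_unit = tokenizer_md.create_metadata()
    # process_unit is a _metadata_fb.ProcessUnitT carrying the delimiter
    # pattern and the vocabulary file entry built by the class above.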
@@ -397,6 +436,56 @@ class InputImageTensorMd(TensorMd):
     return tensor_metadata
 
 
+class InputTextTensorMd(TensorMd):
+  """A container for the input text tensor metadata information.
+
+  Attributes:
+    tokenizer_md: information of the tokenizer in the input text tensor, if any.
+  """
+
+  def __init__(self,
+               name: Optional[str] = None,
+               description: Optional[str] = None,
+               tokenizer_md: Optional[RegexTokenizerMd] = None):
+    """Initializes the instance of InputTextTensorMd.
+
+    Args:
+      name: name of the tensor.
+      description: description of what the tensor is.
+      tokenizer_md: information of the tokenizer in the input text tensor, if
+        any. Only `RegexTokenizer` [1] is currently supported. If the tokenizer
+        is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to the
+        `BertInputTensorsMd` class.
+      [1]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
+      [2]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+      [3]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+    """
+    super().__init__(name, description)
+    self.tokenizer_md = tokenizer_md
+
+  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
+    """Creates the input text metadata based on the information.
+
+    Returns:
+      A Flatbuffers Python object of the input text metadata.
+
+    Raises:
+      ValueError: if the type of tokenizer_md is unsupported.
+    """
+    if not isinstance(self.tokenizer_md, (type(None), RegexTokenizerMd)):
+      raise ValueError(
+          f"The type of tokenizer_options, {type(self.tokenizer_md)}, is "
+          f"unsupported")
+
+    tensor_metadata = super().create_metadata()
+    if self.tokenizer_md:
+      tensor_metadata.processUnits = [self.tokenizer_md.create_metadata()]
+    return tensor_metadata
+
+
 class ClassificationTensorMd(TensorMd):
   """A container for the classification tensor metadata information.
 
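For reference, the tensor container composes with the tokenizer container as sketched below (following the metadata_info_test changes further down; names and paths are placeholders):

    # Hypothetical usage sketch, not part of the commit itself.
    text_tensor_md = metadata_info.InputTextTensorMd(
        name="input text",
        description="The input string.",
        tokenizer_md=metadata_info.RegexTokenizerMd(
            delim_regex_pattern=r"[^\w\']+", vocab_file_path="vocab.txt"))
    tensor_metadata = text_tensor_md.create_metadata()
    # tensor_metadata.processUnits now holds the RegexTokenizerOptions
    # process unit, matching input_text_tensor_meta.json added below.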
@@ -29,6 +29,9 @@ from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
 
 _INPUT_IMAGE_NAME = 'image'
 _INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
+_INPUT_REGEX_TEXT_NAME = 'input_text'
+_INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
+                                 'text to be processed.')
 _OUTPUT_CLASSIFICATION_NAME = 'score'
 _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
 
@@ -82,6 +85,22 @@ class ScoreThresholding:
   global_score_threshold: float
 
 
+@dataclasses.dataclass
+class RegexTokenizer:
+  """Parameters of the Regex tokenizer [1] metadata information.
+
+  [1]:
+  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
+
+  Attributes:
+    delim_regex_pattern: the regular expression to segment strings and create
+      tokens.
+    vocab_file_path: path to the vocabulary file.
+  """
+  delim_regex_pattern: str
+  vocab_file_path: str
+
+
 class Labels(object):
   """Simple container holding classification labels of a particular tensor.
 
@@ -355,11 +374,11 @@ class MetadataWriter(object):
     if os.path.exists(self._temp_folder.name):
       self._temp_folder.cleanup()
 
-  def add_genernal_info(
+  def add_general_info(
       self,
       model_name: str,
       model_description: Optional[str] = None) -> 'MetadataWriter':
-    """Adds a genernal info metadata for the general metadata informantion."""
+    """Adds a general info metadata for the general metadata information."""
     # Will overwrite the previous `self._general_md` if exists.
     self._general_md = metadata_info.GeneralMd(
         name=model_name, description=model_description)
@@ -415,6 +434,34 @@ class MetadataWriter(object):
     self._input_mds.append(input_md)
     return self
 
+  def add_regex_text_input(
+      self,
+      regex_tokenizer: RegexTokenizer,
+      name: str = _INPUT_REGEX_TEXT_NAME,
+      description: str = _INPUT_REGEX_TEXT_DESCRIPTION) -> 'MetadataWriter':
+    """Adds an input text metadata for the text input with regex tokenizer.
+
+    Args:
+      regex_tokenizer: information of the regex tokenizer [1] used to process
+        the input string.
+      name: Name of the input tensor.
+      description: Description of the input tensor.
+
+    Returns:
+      The MetadataWriter instance, which can be used for chained operations.
+
+    [1]:
+    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
+    """
+    tokenizer_md = metadata_info.RegexTokenizerMd(
+        delim_regex_pattern=regex_tokenizer.delim_regex_pattern,
+        vocab_file_path=regex_tokenizer.vocab_file_path)
+    input_md = metadata_info.InputTextTensorMd(
+        name=name, description=description, tokenizer_md=tokenizer_md)
+    self._input_mds.append(input_md)
+    self._associated_files.append(regex_tokenizer.vocab_file_path)
+    return self
+
   def add_classification_output(
       self,
       labels: Optional[Labels] = None,
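The new method slots into the existing chained-writer flow; a sketch assuming a model buffer is already loaded (the buffer, model name, and vocabulary path are placeholders):

    # Hypothetical usage sketch, not part of the commit itself.
    writer = metadata_writer.MetadataWriter(model_buffer)
    writer.add_general_info('my_text_model')
    writer.add_regex_text_input(
        metadata_writer.RegexTokenizer(
            delim_regex_pattern=r"[^\w\']+",
            vocab_file_path='regex_vocab.txt'))
    # The vocabulary file is also recorded as an associated file, so
    # populate() can pack it alongside the metadata.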
@@ -0,0 +1,64 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Writes metadata and label file to the Text classifier models."""
+
+from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
+
+_MODEL_NAME = "TextClassifier"
+_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
+
+
+class MetadataWriter(metadata_writer.MetadataWriterBase):
+  """MetadataWriter to write the metadata into the text classifier."""
+
+  @classmethod
+  def create_for_regex_model(
+      cls, model_buffer: bytearray,
+      regex_tokenizer: metadata_writer.RegexTokenizer,
+      labels: metadata_writer.Labels) -> "MetadataWriter":
+    """Creates MetadataWriter for a TFLite model with a regex tokenizer.
+
+    The parameters required in this method are mandatory when using MediaPipe
+    Tasks.
+
+    Note that only the output TFLite model is used for deployment. The output
+    JSON content is used to interpret the metadata content.
+
+    Args:
+      model_buffer: A valid flatbuffer loaded from the TFLite model file.
+      regex_tokenizer: information of the regex tokenizer [1] used to process
+        the input string. If the tokenizer is `BertTokenizer` [2] or
+        `SentencePieceTokenizer` [3], please refer to
+        `create_for_bert_model`.
+      labels: an instance of Labels helper class used in the output
+        classification tensor [4].
+
+      [1]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
+      [2]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+      [3]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+      [4]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
+
+    Returns:
+      A MetadataWriter object.
+    """
+    writer = metadata_writer.MetadataWriter(model_buffer)
+    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
+    writer.add_regex_text_input(regex_tokenizer)
+    writer.add_classification_output(labels)
+    return cls(writer)
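End to end, the new writer is driven as in the test added below (the model, vocabulary, and label paths are placeholders for real files):

    # Hypothetical usage sketch, not part of the commit itself.
    with open("movie_review.tflite", "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_regex_model(
        model_buffer,
        regex_tokenizer=metadata_writer.RegexTokenizer(
            delim_regex_pattern=r"[^\w\']+",
            vocab_file_path="regex_vocab.txt"),
        labels=metadata_writer.Labels().add_from_file("labels.txt"))
    tflite_with_metadata, metadata_json = writer.populate()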
@@ -53,3 +53,17 @@ py_test(
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
+
+py_test(
+    name = "text_classifier_test",
+    srcs = ["text_classifier_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/metadata:data_files",
+        "//mediapipe/tasks/testdata/metadata:model_files",
+    ],
+    deps = [
+        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
+        "//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
+        "//mediapipe/tasks/python/test:test_utils",
+    ],
+)
@@ -191,6 +191,43 @@ class InputImageTensorMdTest(parameterized.TestCase):
         f"{len(norm_mean)} and {len(norm_std)}", str(error.exception))
 
 
+class InputTextTensorMdTest(absltest.TestCase):
+
+  _NAME = "input text"
+  _DESCRIPTION = "The input string."
+  _VOCAB_FILE = "vocab.txt"
+  _DELIM_REGEX_PATTERN = r"[^\w\']+"
+  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
+      os.path.join(_TEST_DATA_DIR, "input_text_tensor_meta.json"))
+  _EXPECTED_TENSOR_DEFAULT_JSON = test_utils.get_test_data_path(
+      os.path.join(_TEST_DATA_DIR, "input_text_tensor_default_meta.json"))
+
+  def test_create_metadata_should_succeed(self):
+    regex_tokenizer_md = metadata_info.RegexTokenizerMd(
+        self._DELIM_REGEX_PATTERN, self._VOCAB_FILE)
+
+    text_tensor_md = metadata_info.InputTextTensorMd(self._NAME,
+                                                     self._DESCRIPTION,
+                                                     regex_tokenizer_md)
+
+    metadata_json = _metadata.convert_to_json(
+        _create_dummy_model_metadata_with_tensor(
+            text_tensor_md.create_metadata()))
+    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+  def test_create_metadata_by_default_should_succeed(self):
+    text_tensor_md = metadata_info.InputTextTensorMd()
+
+    metadata_json = _metadata.convert_to_json(
+        _create_dummy_model_metadata_with_tensor(
+            text_tensor_md.create_metadata()))
+    with open(self._EXPECTED_TENSOR_DEFAULT_JSON, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+
 class ClassificationTensorMdTest(parameterized.TestCase):
 
   _NAME = "probability"
@@ -113,7 +113,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
   def test_initialize_and_populate(self):
     writer = metadata_writer.MetadataWriter.create(
         self.image_classifier_model_buffer)
-    writer.add_genernal_info(
+    writer.add_general_info(
         model_name='my_image_model', model_description='my_description')
     tflite_model, metadata_json = writer.populate()
     self.assertLen(tflite_model, 1882986)
@@ -142,7 +142,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
   def test_add_feature_input_output(self):
     writer = metadata_writer.MetadataWriter.create(
         self.image_classifier_model_buffer)
-    writer.add_genernal_info(
+    writer.add_general_info(
         model_name='my_model', model_description='my_description')
     writer.add_feature_input(
         name='input_tesnor', description='a feature input tensor')
@@ -191,7 +191,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
   def test_image_classifier(self):
     writer = metadata_writer.MetadataWriter.create(
         self.image_classifier_model_buffer)
-    writer.add_genernal_info(
+    writer.add_general_info(
         model_name='image_classifier',
         model_description='Imagenet classification model')
     writer.add_image_input(
@@ -282,7 +282,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
 
   def test_image_classifier_with_locale_and_score_calibration(self):
     writer = metadata_writer.MetadataWriter(self.image_classifier_model_buffer)
-    writer.add_genernal_info(
+    writer.add_general_info(
         model_name='image_classifier',
         model_description='Classify the input image.')
     writer.add_image_input(
@@ -0,0 +1,51 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for metadata_writer.text_classifier."""
+
+from absl.testing import absltest
+
+from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
+from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
+from mediapipe.tasks.python.test import test_utils
+
+_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
+_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
+_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
+                                            "movie_review_labels.txt")
+_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
+_DELIM_REGEX_PATTERN = r"[^\w\']+"
+_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
+
+
+class TextClassifierTest(absltest.TestCase):
+
+  def test_write_metadata(self,):
+    with open(_MODEL, "rb") as f:
+      model_buffer = f.read()
+    writer = text_classifier.MetadataWriter.create_for_regex_model(
+        model_buffer,
+        regex_tokenizer=metadata_writer.RegexTokenizer(
+            delim_regex_pattern=_DELIM_REGEX_PATTERN,
+            vocab_file_path=_VOCAB_FILE),
+        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
+    _, metadata_json = writer.populate()
+
+    with open(_JSON_FILE, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+
+if __name__ == "__main__":
+  absltest.main()
mediapipe/tasks/testdata/metadata/BUILD (vendored, 12 changes)

@@ -31,6 +31,7 @@ mediapipe_files(srcs = [
     "mobilenet_v2_1.0_224_quant.tflite",
     "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
     "mobilenet_v2_1.0_224_without_metadata.tflite",
+    "movie_review.tflite",
 ])
 
 exports_files([
@@ -54,6 +55,11 @@ exports_files([
     "labels.txt",
     "mobilenet_v2_1.0_224.json",
     "mobilenet_v2_1.0_224_quant.json",
+    "input_text_tensor_meta.json",
+    "input_text_tensor_default_meta.json",
+    "movie_review_labels.txt",
+    "regex_vocab.txt",
+    "movie_review.json",
 ])
 
 filegroup(
@@ -67,6 +73,7 @@ filegroup(
         "mobilenet_v2_1.0_224_quant.tflite",
         "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
         "mobilenet_v2_1.0_224_without_metadata.tflite",
+        "movie_review.tflite",
     ],
 )
 
@@ -86,9 +93,14 @@ filegroup(
         "input_image_tensor_float_meta.json",
         "input_image_tensor_uint8_meta.json",
         "input_image_tensor_unsupported_meta.json",
+        "input_text_tensor_default_meta.json",
+        "input_text_tensor_meta.json",
         "labels.txt",
         "mobilenet_v2_1.0_224.json",
         "mobilenet_v2_1.0_224_quant.json",
+        "movie_review.json",
+        "movie_review_labels.txt",
+        "regex_vocab.txt",
         "score_calibration.txt",
         "score_calibration_file_meta.json",
         "score_calibration_tensor_meta.json",
mediapipe/tasks/testdata/metadata/input_text_tensor_default_meta.json (vendored, new file, 17 lines)

@@ -0,0 +1,17 @@
+{
+  "subgraph_metadata": [
+    {
+      "input_tensor_metadata": [
+        {
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+          }
+        }
+      ]
+    }
+  ]
+}
mediapipe/tasks/testdata/metadata/input_text_tensor_meta.json (vendored, new file, 34 lines)

@@ -0,0 +1,34 @@
+{
+  "subgraph_metadata": [
+    {
+      "input_tensor_metadata": [
+        {
+          "name": "input text",
+          "description": "The input string.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "process_units": [
+            {
+              "options_type": "RegexTokenizerOptions",
+              "options": {
+                "delim_regex_pattern": "[^\\w\\']+",
+                "vocab_file": [
+                  {
+                    "name": "vocab.txt",
+                    "description": "Vocabulary file to convert natural language words to embedding vectors.",
+                    "type": "VOCABULARY"
+                  }
+                ]
+              }
+            }
+          ],
+          "stats": {
+          }
+        }
+      ]
+    }
+  ]
+}
mediapipe/tasks/testdata/metadata/movie_review.json (vendored, new file, 63 lines)

@@ -0,0 +1,63 @@
+{
+  "name": "TextClassifier",
+  "description": "Classify the input text into a set of known categories.",
+  "subgraph_metadata": [
+    {
+      "input_tensor_metadata": [
+        {
+          "name": "input_text",
+          "description": "Embedding vectors representing the input text to be processed.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "process_units": [
+            {
+              "options_type": "RegexTokenizerOptions",
+              "options": {
+                "delim_regex_pattern": "[^\\w\\']+",
+                "vocab_file": [
+                  {
+                    "name": "regex_vocab.txt",
+                    "description": "Vocabulary file to convert natural language words to embedding vectors.",
+                    "type": "VOCABULARY"
+                  }
+                ]
+              }
+            }
+          ],
+          "stats": {
+          }
+        }
+      ],
+      "output_tensor_metadata": [
+        {
+          "name": "score",
+          "description": "Score of the labels respectively.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+            "max": [
+              1.0
+            ],
+            "min": [
+              0.0
+            ]
+          },
+          "associated_files": [
+            {
+              "name": "labels.txt",
+              "description": "Labels for categories that the model can recognize.",
+              "type": "TENSOR_AXIS_LABELS"
+            }
+          ]
+        }
+      ]
+    }
+  ],
+  "min_parser_version": "1.2.1"
+}
mediapipe/tasks/testdata/metadata/movie_review_labels.txt (vendored, new file, 2 lines)

@@ -0,0 +1,2 @@
+Negative
+Positive
mediapipe/tasks/testdata/metadata/regex_vocab.txt (vendored, new file, 10000 lines)

File diff suppressed because it is too large.
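The suppressed vocabulary file is plain text; for regex tokenizers it conventionally maps one token per line to its embedding row index. Illustrative entries only, not the file's actual contents:

    <PAD> 0
    <START> 1
    <UNKNOWN> 2
    the 3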
third_party/external_files.bzl (vendored, 28 changes)

@@ -346,6 +346,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_unsupported_meta.json?generation=1665422835757143"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_input_text_tensor_default_meta_json",
+        sha256 = "9723e59960b0e6ca60d120494c32e798b054ea6e5a441b359c84f759bd2b3a36",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/input_text_tensor_default_meta.json?generation=1667855382021347"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_input_text_tensor_meta_json",
+        sha256 = "c6782f676220e2cc89b70bacccb649fc848c18e33bedc449bf49f5d839b3cc6c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/input_text_tensor_meta.json?generation=1667855384891533"],
+    )
+
     http_file(
         name = "com_google_mediapipe_iris_and_gaze_tflite",
         sha256 = "b6dcb860a92a3c7264a8e50786f46cecb529672cdafc17d39c78931257da661d",
@@ -390,8 +402,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_labels_txt",
-        sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1665988394538324"],
+        sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667855388142641"],
     )
 
     http_file(
@@ -538,6 +550,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/model_without_metadata.tflite?generation=1661875850966737"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_movie_review_json",
+        sha256 = "89ad347ad1cb7c587da144de6efbadec1d3e8ff0cd13e379dd16661a8186fbb5",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667855392734031"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_movie_review_tflite",
+        sha256 = "3935ee73b13d435327d05af4d6f37dc3c146e117e1c3d572ae4d2ae0f5f412fe",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.tflite?generation=1667855395736217"],
+    )
+
     http_file(
         name = "com_google_mediapipe_mozart_square_jpg",
         sha256 = "4feb4dadc5d6f853ade57b8c9d4c9a1f5ececd6469616c8e505f9a14823392b6",