Metadata Writer: Add Bert metadata writer in Text Classifier.

PiperOrigin-RevId: 487354439
This commit is contained in:
Yuqi Li 2022-11-09 14:51:46 -08:00 committed by Copybara-Service
parent 116b4bb6c4
commit d2142e86a9
13 changed files with 31235 additions and 15 deletions

View File

@ -12,9 +12,9 @@ py_library(
srcs = [
"metadata_info.py",
],
- srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":writer_utils",
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:schema_py",
],

View File

@ -14,12 +14,14 @@
# ==============================================================================
"""Helper classes for common model metadata information."""
import collections
import csv
import os
- from typing import List, Optional, Type
from typing import List, Optional, Type, Union
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
# Min and max values for UINT8 tensors.
_MIN_UINT8 = 0
@ -267,6 +269,86 @@ class RegexTokenizerMd:
return tokenizer
class BertTokenizerMd:
"""A container for the Bert tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
"""
def __init__(self, vocab_file_path: str):
"""Initializes a BertTokenizerMd object.
Args:
vocab_file_path: path to the vocabulary file.
"""
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the Bert tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the Bert tokenizer metadata.
"""
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = _VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
tokenizer.options.vocabFile = [vocab]
return tokenizer
class SentencePieceTokenizerMd:
"""A container for the sentence piece tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
_SP_MODEL_DESCRIPTION = "The sentence piece model file."
_SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
" This file is optional during tokenization, while the sentence piece "
"model is mandatory.")
def __init__(self,
sentence_piece_model_path: str,
vocab_file_path: Optional[str] = None):
"""Initializes a SentencePieceTokenizerMd object.
Args:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
self._sentence_piece_model_path = sentence_piece_model_path
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the sentence piece tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the sentence piece tokenizer metadata.
"""
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = (
_metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
sp_model = _metadata_fb.AssociatedFileT()
sp_model.name = self._sentence_piece_model_path
sp_model.description = self._SP_MODEL_DESCRIPTION
tokenizer.options.sentencePieceModel = [sp_model]
if self._vocab_file_path:
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer.options.vocabFile = [vocab]
return tokenizer
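As a rough usage sketch of the two containers above (not part of this commit; the vocabulary and model file names are placeholders):

    from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
    from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

    # Bert tokenizer: a single vocabulary file is packed as a VOCABULARY associated file.
    bert_unit = metadata_info.BertTokenizerMd(vocab_file_path="vocab.txt").create_metadata()
    assert bert_unit.optionsType == _metadata_fb.ProcessUnitOptions.BertTokenizerOptions

    # Sentence piece tokenizer: the model file is mandatory, the vocabulary is optional.
    sp_unit = metadata_info.SentencePieceTokenizerMd(
        sentence_piece_model_path="sp.model",
        vocab_file_path="vocab.txt").create_metadata()
    assert len(sp_unit.options.sentencePieceModel) == 1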
class TensorMd:
"""A container for common tensor metadata information.
@ -486,6 +568,145 @@ class InputTextTensorMd(TensorMd):
return tensor_metadata
def _get_file_paths(files: List[_metadata_fb.AssociatedFileT]) -> List[str]:
"""Gets file paths from a list of associated files."""
if not files:
return []
return [file.name for file in files]
def _get_tokenizer_associated_files(
tokenizer_options: Optional[
Union[_metadata_fb.BertTokenizerOptionsT,
_metadata_fb.SentencePieceTokenizerOptionsT]]
) -> List[str]:
"""Gets a list of associated files packed in the tokenizer_options.
Args:
tokenizer_options: a tokenizer metadata object. Supports the following
tokenizer types:
1. BertTokenizerOptions:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
2. SentencePieceTokenizerOptions:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
Returns:
A list of associated files included in tokenizer_options.
"""
if not tokenizer_options:
return []
if isinstance(tokenizer_options, _metadata_fb.BertTokenizerOptionsT):
return _get_file_paths(tokenizer_options.vocabFile)
elif isinstance(tokenizer_options,
_metadata_fb.SentencePieceTokenizerOptionsT):
return _get_file_paths(tokenizer_options.vocabFile) + _get_file_paths(
tokenizer_options.sentencePieceModel)
else:
return []
class BertInputTensorsMd:
"""A container for the input tensor metadata information of Bert models."""
_IDS_NAME = "ids"
_IDS_DESCRIPTION = "Tokenized ids of the input text."
_MASK_NAME = "mask"
_MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
"tokens.")
_SEGMENT_IDS_NAME = "segment_ids"
_SEGMENT_IDS_DESCRIPTION = (
"0 for the first sequence, 1 for the second sequence if exists.")
def __init__(self,
model_buffer: bytearray,
ids_name: str,
mask_name: str,
segment_name: str,
tokenizer_md: Union[None, BertTokenizerMd,
SentencePieceTokenizerMd] = None):
"""Initializes a BertInputTensorsMd object.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata.
Args:
model_buffer: valid buffer of the model file.
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if exists.
tokenizer_md: information of the tokenizer used to process the input
string, if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
refer to `InputTensorsMd`.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
[2]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
[3]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
"""
# Verify that tflite_input_names (read from the model) and
# input_name (collected from users) are aligned.
tflite_input_names = writer_utils.get_input_tensor_names(model_buffer)
input_names = [ids_name, mask_name, segment_name]
if collections.Counter(tflite_input_names) != collections.Counter(
input_names):
raise ValueError(
f"The input tensor names ({input_names}) do not match the tensor "
f"names read from the model ({tflite_input_names}).")
ids_md = TensorMd(
name=self._IDS_NAME,
description=self._IDS_DESCRIPTION,
tensor_name=ids_name)
mask_md = TensorMd(
name=self._MASK_NAME,
description=self._MASK_DESCRIPTION,
tensor_name=mask_name)
segment_ids_md = TensorMd(
name=self._SEGMENT_IDS_NAME,
description=self._SEGMENT_IDS_DESCRIPTION,
tensor_name=segment_name)
self._input_md = [ids_md, mask_md, segment_ids_md]
if not isinstance(tokenizer_md,
(type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
raise ValueError(
f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported"
)
self._tokenizer_md = tokenizer_md
def create_input_process_unit_metadata(
self) -> List[_metadata_fb.ProcessUnitT]:
"""Creates the input process unit metadata."""
if self._tokenizer_md:
return [self._tokenizer_md.create_metadata()]
else:
return []
def get_tokenizer_associated_files(self) -> List[str]:
"""Gets the associated files that are packed in the tokenizer."""
if self._tokenizer_md:
return _get_tokenizer_associated_files(
self._tokenizer_md.create_metadata().options)
else:
return []
@property
def input_md(self) -> List[TensorMd]:
return self._input_md
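For illustration only (not part of the diff), a sketch of wiring these pieces together; the model path is a placeholder and the tensor names are the Model Maker defaults used later in text_classifier.py. The constructor raises a ValueError if the three names do not match the tensor names read from the model:

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

    with open("bert_classifier.tflite", "rb") as f:  # placeholder model path
      model_buffer = bytearray(f.read())

    bert_inputs = metadata_info.BertInputTensorsMd(
        model_buffer,
        ids_name="serving_default_input_word_ids:0",
        mask_name="serving_default_input_mask:0",
        segment_name="serving_default_input_type_ids:0",
        tokenizer_md=metadata_info.BertTokenizerMd("mobilebert_vocab.txt"))

    input_mds = bert_inputs.input_md  # [ids, mask, segment_ids] TensorMd entries
    process_units = bert_inputs.create_input_process_unit_metadata()
    vocab_files = bert_inputs.get_tokenizer_associated_files()  # ["mobilebert_vocab.txt"]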
class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information.

View File

@ -19,7 +19,7 @@ import csv
import dataclasses
import os
import tempfile
- from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
@ -101,6 +101,34 @@ class RegexTokenizer:
vocab_file_path: str
@dataclasses.dataclass
class BertTokenizer:
"""Parameters of the Bert tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
Attributes:
vocab_file_path: path to the vocabulary file.
"""
vocab_file_path: str
@dataclasses.dataclass
class SentencePieceTokenizer:
"""Parameters of the sentence piece tokenizer tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
Attributes:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
sentence_piece_model_path: str
vocab_file_path: Optional[str] = None
class Labels(object):
"""Simple container holding classification labels of a particular tensor.
@ -282,7 +310,9 @@ def _create_metadata_buffer(
model_buffer: bytearray,
general_md: Optional[metadata_info.GeneralMd] = None,
input_md: Optional[List[metadata_info.TensorMd]] = None,
- output_md: Optional[List[metadata_info.TensorMd]] = None) -> bytearray:
output_md: Optional[List[metadata_info.TensorMd]] = None,
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None
) -> bytearray:
"""Creates a buffer of the metadata. """Creates a buffer of the metadata.
Args: Args:
@ -290,7 +320,9 @@ def _create_metadata_buffer(
general_md: general information about the model.
input_md: metadata information of the input tensors.
output_md: metadata information of the output tensors.
input_process_units: a list of metadata of the input process units [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
Returns:
A buffer of the metadata.
@ -325,6 +357,8 @@ def _create_metadata_buffer(
subgraph_metadata = metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputTensorMetadata = input_metadata
subgraph_metadata.outputTensorMetadata = output_metadata
if input_process_units:
subgraph_metadata.inputProcessUnits = input_process_units
# Create the whole model metadata.
if general_md is None:
@ -366,6 +400,7 @@ class MetadataWriter(object):
self._model_buffer = model_buffer
self._general_md = None
self._input_mds = []
self._input_process_units = []
self._output_mds = []
self._associated_files = []
self._temp_folder = tempfile.TemporaryDirectory()
@ -416,7 +451,7 @@ class MetadataWriter(object):
description: Description of the input tensor.
Returns:
- The MetaWriter instance, can be used for chained operation.
The MetadataWriter instance, can be used for chained operation.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
@ -448,7 +483,7 @@ class MetadataWriter(object):
description: Description of the input tensor.
Returns:
- The MetaWriter instance, can be used for chained operation.
The MetadataWriter instance, can be used for chained operation.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
@ -462,6 +497,63 @@ class MetadataWriter(object):
self._associated_files.append(regex_tokenizer.vocab_file_path)
return self
def add_bert_text_input(self, tokenizer: Union[BertTokenizer,
SentencePieceTokenizer],
ids_name: str, mask_name: str,
segment_name: str) -> 'MetadataWriter':
"""Adds an metadata for the text input with bert / sentencepiece tokenizer.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata.
Args:
tokenizer: information of the tokenizer used to process the input string,
if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2].
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if exists.
Returns:
The MetadataWriter instance, can be used for chained operation.
Raises:
ValueError: if the type of tokenizer is not BertTokenizer or
SentencePieceTokenizer.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
if isinstance(tokenizer, BertTokenizer):
tokenizer_md = metadata_info.BertTokenizerMd(
vocab_file_path=tokenizer.vocab_file_path)
elif isinstance(tokenizer, SentencePieceTokenizer):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
sentence_piece_model_path=tokenizer.sentence_piece_model_path,
vocab_file_path=tokenizer.vocab_file_path)
else:
raise ValueError(
f'The type of tokenizer, {type(tokenizer)}, is unsupported')
bert_input_md = metadata_info.BertInputTensorsMd(
self._model_buffer,
ids_name,
mask_name,
segment_name,
tokenizer_md=tokenizer_md)
self._input_mds.extend(bert_input_md.input_md)
self._associated_files.extend(
bert_input_md.get_tokenizer_associated_files())
self._input_process_units.extend(
bert_input_md.create_input_process_unit_metadata())
return self
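For context, a minimal usage sketch of `add_bert_text_input` on the generic writer (not part of the diff; the model, vocabulary, and label file paths are placeholders, and the tensor names are the Model Maker defaults):

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

    with open("bert_classifier.tflite", "rb") as f:  # placeholder path
      model_buffer = bytearray(f.read())

    writer = metadata_writer.MetadataWriter(model_buffer)
    writer.add_general_info("TextClassifier", "Classify the input text.")
    writer.add_bert_text_input(
        tokenizer=metadata_writer.BertTokenizer(vocab_file_path="mobilebert_vocab.txt"),
        ids_name="serving_default_input_word_ids:0",
        mask_name="serving_default_input_mask:0",
        segment_name="serving_default_input_type_ids:0")
    writer.add_classification_output(
        metadata_writer.Labels().add_from_file("labels.txt"))  # placeholder label file
    model_with_metadata, metadata_json = writer.populate()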
def add_classification_output(
self,
labels: Optional[Labels] = None,
@ -546,7 +638,8 @@ class MetadataWriter(object):
model_buffer=self._model_buffer,
general_md=self._general_md,
input_md=self._input_mds,
- output_md=self._output_mds)
output_md=self._output_mds,
input_process_units=self._input_process_units)
populator.load_metadata_buffer(metadata_buffer)
if self._associated_files:
populator.load_associated_files(self._associated_files)

View File

@ -14,11 +14,18 @@
# ==============================================================================
"""Writes metadata and label file to the Text classifier models."""
from typing import Union
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
_MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
# The input tensor names of models created by Model Maker.
_DEFAULT_ID_NAME = "serving_default_input_word_ids:0"
_DEFAULT_MASK_NAME = "serving_default_input_mask:0"
_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0"
class MetadataWriter(metadata_writer.MetadataWriterBase):
"""MetadataWriter to write the metadata into the text classifier."""
@ -62,3 +69,51 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
writer.add_regex_text_input(regex_tokenizer)
writer.add_classification_output(labels)
return cls(writer)
@classmethod
def create_for_bert_model(
cls,
model_buffer: bytearray,
tokenizer: Union[metadata_writer.BertTokenizer,
metadata_writer.SentencePieceTokenizer],
labels: metadata_writer.Labels,
ids_name: str = _DEFAULT_ID_NAME,
mask_name: str = _DEFAULT_MASK_NAME,
segment_name: str = _DEFAULT_SEGMENT_ID_NAME,
) -> "MetadataWriter":
"""Creates MetadataWriter for models with {Bert/SentencePiece}Tokenizer.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata. The default values come from Model Maker.
Args:
model_buffer: valid buffer of the model file.
tokenizer: information of the tokenizer used to process the input string,
if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
refer to `create_for_regex_model`.
labels: an instance of Labels helper class used in the output
classification tensor [4].
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if exists.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
Returns:
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_bert_text_input(tokenizer, ids_name, mask_name, segment_name)
writer.add_classification_output(labels)
return cls(writer)
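By way of example (not part of the diff; the file names below are the test assets referenced elsewhere in this commit, used here as placeholders), the convenience factory above reduces to:

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
    from mediapipe.tasks.python.metadata.metadata_writers import text_classifier

    with open("bert_text_classifier_no_metadata.tflite", "rb") as f:
      model_buffer = f.read()

    writer = text_classifier.MetadataWriter.create_for_bert_model(
        model_buffer,
        tokenizer=metadata_writer.BertTokenizer(vocab_file_path="mobilebert_vocab.txt"),
        labels=metadata_writer.Labels().add_from_file("movie_review_labels.txt"))
    # Tensor names default to the Model Maker exports defined above.
    model_with_metadata, metadata_json = writer.populate()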

View File

@ -367,6 +367,42 @@ class ScoreThresholdingMdTest(absltest.TestCase):
self.assertEqual(metadata_json, expected_json)
class BertTokenizerMdTest(absltest.TestCase):
_VOCAB_FILE = "vocab.txt"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "bert_tokenizer_meta.json"))
def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class SentencePieceTokenizerMdTest(absltest.TestCase):
_VOCAB_FILE = "vocab.txt"
_SP_MODEL = "sp.model"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "sentence_piece_tokenizer_meta.json"))
def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
self._SP_MODEL, self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata.

View File

@ -21,28 +21,64 @@ from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils
_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
- _MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_REGEX_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
"movie_review_labels.txt")
- _VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_REGEX_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+"
- _JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_REGEX_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_BERT_MODEL = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_no_metadata.tflite")
_BERT_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR +
"mobilebert_vocab.txt")
_SP_MODEL_FILE = test_utils.get_test_data_path(_TEST_DIR + "30k-clean.model")
_BERT_JSON_FILE = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_with_bert_tokenizer.json")
_SENTENCE_PIECE_JSON_FILE = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_with_sentence_piece.json")
class TextClassifierTest(absltest.TestCase):
- def test_write_metadata(self,):
def test_write_metadata_for_regex_model(self):
- with open(_MODEL, "rb") as f:
with open(_REGEX_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_regex_model(
model_buffer,
regex_tokenizer=metadata_writer.RegexTokenizer(
delim_regex_pattern=_DELIM_REGEX_PATTERN,
- vocab_file_path=_VOCAB_FILE),
vocab_file_path=_REGEX_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
- with open(_JSON_FILE, "r") as f:
with open(_REGEX_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_for_bert(self):
with open(_BERT_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
model_buffer,
tokenizer=metadata_writer.BertTokenizer(_BERT_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_BERT_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_for_sentence_piece(self):
with open(_BERT_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
model_buffer,
tokenizer=metadata_writer.SentencePieceTokenizer(_SP_MODEL_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_SENTENCE_PIECE_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)

View File

@ -23,10 +23,13 @@ package(
)
mediapipe_files(srcs = [
"30k-clean.model",
"bert_text_classifier_no_metadata.tflite",
"mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite",
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
"mobilebert_vocab.txt",
"mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
@ -60,11 +63,17 @@ exports_files([
"movie_review_labels.txt", "movie_review_labels.txt",
"regex_vocab.txt", "regex_vocab.txt",
"movie_review.json", "movie_review.json",
"bert_tokenizer_meta.json",
"bert_text_classifier_with_sentence_piece.json",
"sentence_piece_tokenizer_meta.json",
"bert_text_classifier_with_bert_tokenizer.json",
])
filegroup(
name = "model_files",
srcs = [
"30k-clean.model",
"bert_text_classifier_no_metadata.tflite",
"mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite",
@ -81,6 +90,9 @@ filegroup(
name = "data_files", name = "data_files",
srcs = [ srcs = [
"associated_file_meta.json", "associated_file_meta.json",
"bert_text_classifier_with_bert_tokenizer.json",
"bert_text_classifier_with_sentence_piece.json",
"bert_tokenizer_meta.json",
"bounding_box_tensor_meta.json", "bounding_box_tensor_meta.json",
"classification_tensor_float_meta.json", "classification_tensor_float_meta.json",
"classification_tensor_uint8_meta.json", "classification_tensor_uint8_meta.json",
@ -96,6 +108,7 @@ filegroup(
"input_text_tensor_default_meta.json", "input_text_tensor_default_meta.json",
"input_text_tensor_meta.json", "input_text_tensor_meta.json",
"labels.txt", "labels.txt",
"mobilebert_vocab.txt",
"mobilenet_v2_1.0_224.json", "mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json", "mobilenet_v2_1.0_224_quant.json",
"movie_review.json", "movie_review.json",
@ -105,5 +118,6 @@ filegroup(
"score_calibration_file_meta.json", "score_calibration_file_meta.json",
"score_calibration_tensor_meta.json", "score_calibration_tensor_meta.json",
"score_thresholding_meta.json", "score_thresholding_meta.json",
"sentence_piece_tokenizer_meta.json",
], ],
) )

View File

@ -0,0 +1,84 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "mobilebert_vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}

View File

@ -0,0 +1,83 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "SentencePieceTokenizerOptions",
"options": {
"sentencePiece_model": [
{
"name": "30k-clean.model",
"description": "The sentence piece model file."
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}

View File

@ -0,0 +1,20 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
]
}

File diff suppressed because it is too large

View File

@ -0,0 +1,26 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "SentencePieceTokenizerOptions",
"options": {
"sentencePiece_model": [
{
"name": "sp.model",
"description": "The sentence piece model file."
}
],
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.",
"type": "VOCABULARY"
}
]
}
}
]
}
]
}

View File

@ -28,12 +28,36 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"], urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
) )
http_file(
name = "com_google_mediapipe_bert_text_classifier_no_metadata_tflite",
sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_no_metadata.tflite?generation=1667948360250899"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_tflite",
sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_with_bert_tokenizer_json",
sha256 = "49f148a13a4e3b486b1d3c2400e46e5ebd0d375674c0154278b835760e873a95",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_bert_tokenizer.json?generation=1667948363241334"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_with_sentence_piece_json",
sha256 = "113091f3892691de57e379387256b2ce0cc18a1b5185af866220a46da8221f26",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_sentence_piece.json?generation=1667948366009530"],
)
http_file(
name = "com_google_mediapipe_bert_tokenizer_meta_json",
sha256 = "116d70c7c3ef413a8bff54ab758f9ed3d6e51fdc5621d8c920ad2f0035831804",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_tokenizer_meta.json?generation=1667948368809108"],
)
http_file(
name = "com_google_mediapipe_bounding_box_tensor_meta_json",
sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",
@ -790,6 +814,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"], urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
) )
http_file(
name = "com_google_mediapipe_sentence_piece_tokenizer_meta_json",
sha256 = "416bfe231710502e4a93e1b1950c0c6e5db49cffb256d241ef3d3f2d0d57718b",
urls = ["https://storage.googleapis.com/mediapipe-assets/sentence_piece_tokenizer_meta.json?generation=1667948375508564"],
)
http_file(
name = "com_google_mediapipe_speech_16000_hz_mono_wav",
sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6",