Metadata Writer: Add Bert metadata writer in Text Classifier.

PiperOrigin-RevId: 487354439
Yuqi Li 2022-11-09 14:51:46 -08:00 committed by Copybara-Service
parent 116b4bb6c4
commit d2142e86a9
13 changed files with 31235 additions and 15 deletions


@ -12,9 +12,9 @@ py_library(
srcs = [
"metadata_info.py",
],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":writer_utils",
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:schema_py",
],


@ -14,12 +14,14 @@
# ==============================================================================
"""Helper classes for common model metadata information."""
import collections
import csv
import os
from typing import List, Optional, Type
from typing import List, Optional, Type, Union
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
# Min and max values for UINT8 tensors.
_MIN_UINT8 = 0
@ -267,6 +269,86 @@ class RegexTokenizerMd:
return tokenizer
class BertTokenizerMd:
"""A container for the Bert tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
"""
def __init__(self, vocab_file_path: str):
"""Initializes a BertTokenizerMd object.
Args:
vocab_file_path: path to the vocabulary file.
"""
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the Bert tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the Bert tokenizer metadata.
"""
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = _VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
tokenizer.options.vocabFile = [vocab]
return tokenizer
class SentencePieceTokenizerMd:
"""A container for the sentence piece tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
_SP_MODEL_DESCRIPTION = "The sentence piece model file."
_SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
" This file is optional during tokenization, while the sentence piece "
"model is mandatory.")
def __init__(self,
sentence_piece_model_path: str,
vocab_file_path: Optional[str] = None):
"""Initializes a SentencePieceTokenizerMd object.
Args:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
self._sentence_piece_model_path = sentence_piece_model_path
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the sentence piece tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the sentence piece tokenizer metadata.
"""
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = (
_metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
sp_model = _metadata_fb.AssociatedFileT()
sp_model.name = self._sentence_piece_model_path
sp_model.description = self._SP_MODEL_DESCRIPTION
tokenizer.options.sentencePieceModel = [sp_model]
if self._vocab_file_path:
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer.options.vocabFile = [vocab]
return tokenizer
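A minimal sketch of how these two containers are meant to be used (the vocabulary and model paths below are placeholders, not files from this change):

# Build the tokenizer process unit metadata from placeholder file paths.
bert_md = BertTokenizerMd(vocab_file_path="vocab.txt")
bert_unit = bert_md.create_metadata()
# bert_unit.optionsType is ProcessUnitOptions.BertTokenizerOptions and
# bert_unit.options.vocabFile[0].name == "vocab.txt".
sp_md = SentencePieceTokenizerMd(
    sentence_piece_model_path="sp.model", vocab_file_path="vocab.txt")
sp_unit = sp_md.create_metadata()
# sp_unit.options.sentencePieceModel[0].name == "sp.model".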
class TensorMd:
"""A container for common tensor metadata information.
@ -486,6 +568,145 @@ class InputTextTensorMd(TensorMd):
return tensor_metadata
def _get_file_paths(files: List[_metadata_fb.AssociatedFileT]) -> List[str]:
"""Gets file paths from a list of associated files."""
if not files:
return []
return [file.name for file in files]
def _get_tokenizer_associated_files(
tokenizer_options: Optional[
Union[_metadata_fb.BertTokenizerOptionsT,
_metadata_fb.SentencePieceTokenizerOptionsT]]
) -> List[str]:
"""Gets a list of associated files packed in the tokenizer_options.
Args:
tokenizer_options: a tokenizer metadata object. Supports the following
tokenizer types:
1. BertTokenizerOptions:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
2. SentencePieceTokenizerOptions:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
Returns:
A list of associated files included in tokenizer_options.
"""
if not tokenizer_options:
return []
if isinstance(tokenizer_options, _metadata_fb.BertTokenizerOptionsT):
return _get_file_paths(tokenizer_options.vocabFile)
elif isinstance(tokenizer_options,
_metadata_fb.SentencePieceTokenizerOptionsT):
return _get_file_paths(tokenizer_options.vocabFile) + _get_file_paths(
tokenizer_options.sentencePieceModel)
else:
return []
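As a small illustration of this helper, using the tokenizer containers defined above with placeholder file names:

# A Bert tokenizer contributes only its vocabulary file.
bert_options = BertTokenizerMd("vocab.txt").create_metadata().options
assert _get_tokenizer_associated_files(bert_options) == ["vocab.txt"]
# A sentence piece tokenizer contributes both the vocabulary and the model file.
sp_options = SentencePieceTokenizerMd("sp.model", "vocab.txt").create_metadata().options
assert sorted(_get_tokenizer_associated_files(sp_options)) == ["sp.model", "vocab.txt"]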
class BertInputTensorsMd:
"""A container for the input tensor metadata information of Bert models."""
_IDS_NAME = "ids"
_IDS_DESCRIPTION = "Tokenized ids of the input text."
_MASK_NAME = "mask"
_MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
"tokens.")
_SEGMENT_IDS_NAME = "segment_ids"
_SEGMENT_IDS_DESCRIPTION = (
"0 for the first sequence, 1 for the second sequence if exists.")
def __init__(self,
model_buffer: bytearray,
ids_name: str,
mask_name: str,
segment_name: str,
tokenizer_md: Union[None, BertTokenizerMd,
SentencePieceTokenizerMd] = None):
"""Initializes a BertInputTensorsMd object.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata.
Args:
model_buffer: valid buffer of the model file.
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if it exists.
tokenizer_md: information of the tokenizer used to process the input
string, if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
refer to `InputTensorsMd`.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
[2]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
[3]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
"""
# Verify that tflite_input_names (read from the model) and
# input_names (collected from users) are aligned.
tflite_input_names = writer_utils.get_input_tensor_names(model_buffer)
input_names = [ids_name, mask_name, segment_name]
if collections.Counter(tflite_input_names) != collections.Counter(
input_names):
raise ValueError(
f"The input tensor names ({input_names}) do not match the tensor "
f"names read from the model ({tflite_input_names}).")
ids_md = TensorMd(
name=self._IDS_NAME,
description=self._IDS_DESCRIPTION,
tensor_name=ids_name)
mask_md = TensorMd(
name=self._MASK_NAME,
description=self._MASK_DESCRIPTION,
tensor_name=mask_name)
segment_ids_md = TensorMd(
name=self._SEGMENT_IDS_NAME,
description=self._SEGMENT_IDS_DESCRIPTION,
tensor_name=segment_name)
self._input_md = [ids_md, mask_md, segment_ids_md]
if not isinstance(tokenizer_md,
(type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
raise ValueError(
f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported"
)
self._tokenizer_md = tokenizer_md
def create_input_process_unit_metadata(
self) -> List[_metadata_fb.ProcessUnitT]:
"""Creates the input process unit metadata."""
if self._tokenizer_md:
return [self._tokenizer_md.create_metadata()]
else:
return []
def get_tokenizer_associated_files(self) -> List[str]:
"""Gets the associated files that are packed in the tokenizer."""
if self._tokenizer_md:
return _get_tokenizer_associated_files(
self._tokenizer_md.create_metadata().options)
else:
return []
@property
def input_md(self) -> List[TensorMd]:
return self._input_md
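A hedged usage sketch of BertInputTensorsMd, assuming a placeholder model path whose three input tensors carry the Model Maker default names:

# The three names must match the model's input tensor names exactly;
# otherwise the constructor raises ValueError.
with open("bert_classifier.tflite", "rb") as f:  # placeholder path
  model_buffer = bytearray(f.read())
bert_inputs_md = BertInputTensorsMd(
    model_buffer,
    ids_name="serving_default_input_word_ids:0",
    mask_name="serving_default_input_mask:0",
    segment_name="serving_default_input_type_ids:0",
    tokenizer_md=BertTokenizerMd(vocab_file_path="vocab.txt"))
process_units = bert_inputs_md.create_input_process_unit_metadata()
vocab_files = bert_inputs_md.get_tokenizer_associated_files()  # ["vocab.txt"]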
class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information.


@ -19,7 +19,7 @@ import csv
import dataclasses
import os
import tempfile
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
@ -101,6 +101,34 @@ class RegexTokenizer:
vocab_file_path: str
@dataclasses.dataclass
class BertTokenizer:
"""Parameters of the Bert tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
Attributes:
vocab_file_path: path to the vocabulary file.
"""
vocab_file_path: str
@dataclasses.dataclass
class SentencePieceTokenizer:
"""Parameters of the sentence piece tokenizer tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
Attributes:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
sentence_piece_model_path: str
vocab_file_path: Optional[str] = None
class Labels(object):
"""Simple container holding classification labels of a particular tensor.
@ -282,7 +310,9 @@ def _create_metadata_buffer(
model_buffer: bytearray,
general_md: Optional[metadata_info.GeneralMd] = None,
input_md: Optional[List[metadata_info.TensorMd]] = None,
output_md: Optional[List[metadata_info.TensorMd]] = None) -> bytearray:
output_md: Optional[List[metadata_info.TensorMd]] = None,
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None
) -> bytearray:
"""Creates a buffer of the metadata.
Args:
@ -290,7 +320,9 @@ def _create_metadata_buffer(
general_md: general information about the model.
input_md: metadata information of the input tensors.
output_md: metadata information of the output tensors.
input_process_units: a list of metadata of the input process units [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
Returns:
A buffer of the metadata.
@ -325,6 +357,8 @@ def _create_metadata_buffer(
subgraph_metadata = metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputTensorMetadata = input_metadata
subgraph_metadata.outputTensorMetadata = output_metadata
if input_process_units:
subgraph_metadata.inputProcessUnits = input_process_units
# Create the whole model metadata.
if general_md is None:
@ -366,6 +400,7 @@ class MetadataWriter(object):
self._model_buffer = model_buffer
self._general_md = None
self._input_mds = []
self._input_process_units = []
self._output_mds = []
self._associated_files = []
self._temp_folder = tempfile.TemporaryDirectory()
@ -416,7 +451,7 @@ class MetadataWriter(object):
description: Description of the input tensor.
Returns:
The MetaWriter instance, can be used for chained operation.
The MetadataWriter instance, which can be used for chained operations.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
@ -448,7 +483,7 @@ class MetadataWriter(object):
description: Description of the input tensor.
Returns:
The MetaWriter instance, can be used for chained operation.
The MetadataWriter instance, which can be used for chained operations.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
@ -462,6 +497,63 @@ class MetadataWriter(object):
self._associated_files.append(regex_tokenizer.vocab_file_path)
return self
def add_bert_text_input(self, tokenizer: Union[BertTokenizer,
SentencePieceTokenizer],
ids_name: str, mask_name: str,
segment_name: str) -> 'MetadataWriter':
"""Adds an metadata for the text input with bert / sentencepiece tokenizer.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata.
Args:
tokenizer: information of the tokenizer used to process the input string,
if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2].
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if it exists.
Returns:
The MetadataWriter instance, which can be used for chained operations.
Raises:
ValueError: if the type of tokenizer is not BertTokenizer or
SentencePieceTokenizer.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
if isinstance(tokenizer, BertTokenizer):
tokenizer_md = metadata_info.BertTokenizerMd(
vocab_file_path=tokenizer.vocab_file_path)
elif isinstance(tokenizer, SentencePieceTokenizer):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
sentence_piece_model_path=tokenizer.sentence_piece_model_path,
vocab_file_path=tokenizer.vocab_file_path)
else:
raise ValueError(
f'The type of tokenizer, {type(tokenizer)}, is unsupported')
bert_input_md = metadata_info.BertInputTensorsMd(
self._model_buffer,
ids_name,
mask_name,
segment_name,
tokenizer_md=tokenizer_md)
self._input_mds.extend(bert_input_md.input_md)
self._associated_files.extend(
bert_input_md.get_tokenizer_associated_files())
self._input_process_units.extend(
bert_input_md.create_input_process_unit_metadata())
return self
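A sketch of the chained builder calls this method supports, with placeholder paths and the Model Maker default tensor names; the text classifier writer below wraps essentially this sequence:

writer = MetadataWriter(model_buffer)
writer.add_general_info("TextClassifier",
                        "Classify the input text into a set of known categories.")
writer.add_bert_text_input(
    tokenizer=BertTokenizer(vocab_file_path="mobilebert_vocab.txt"),
    ids_name="serving_default_input_word_ids:0",
    mask_name="serving_default_input_mask:0",
    segment_name="serving_default_input_type_ids:0")
writer.add_classification_output(Labels().add_from_file("labels.txt"))
tflite_with_metadata, metadata_json = writer.populate()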
def add_classification_output(
self,
labels: Optional[Labels] = None,
@ -546,7 +638,8 @@ class MetadataWriter(object):
model_buffer=self._model_buffer,
general_md=self._general_md,
input_md=self._input_mds,
output_md=self._output_mds)
output_md=self._output_mds,
input_process_units=self._input_process_units)
populator.load_metadata_buffer(metadata_buffer)
if self._associated_files:
populator.load_associated_files(self._associated_files)


@ -14,11 +14,18 @@
# ==============================================================================
"""Writes metadata and label file to the Text classifier models."""
from typing import Union
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
_MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
# The input tensor names of models created by Model Maker.
_DEFAULT_ID_NAME = "serving_default_input_word_ids:0"
_DEFAULT_MASK_NAME = "serving_default_input_mask:0"
_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0"
class MetadataWriter(metadata_writer.MetadataWriterBase):
"""MetadataWriter to write the metadata into the text classifier."""
@ -62,3 +69,51 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
writer.add_regex_text_input(regex_tokenizer)
writer.add_classification_output(labels)
return cls(writer)
@classmethod
def create_for_bert_model(
cls,
model_buffer: bytearray,
tokenizer: Union[metadata_writer.BertTokenizer,
metadata_writer.SentencePieceTokenizer],
labels: metadata_writer.Labels,
ids_name: str = _DEFAULT_ID_NAME,
mask_name: str = _DEFAULT_MASK_NAME,
segment_name: str = _DEFAULT_SEGMENT_ID_NAME,
) -> "MetadataWriter":
"""Creates MetadataWriter for models with {Bert/SentencePiece}Tokenizer.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata. The default values come from Model Maker.
Args:
model_buffer: valid buffer of the model file.
tokenizer: information of the tokenizer used to process the input string,
if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
refer to `create_for_regex_model`.
labels: an instance of Labels helper class used in the output
classification tensor [4].
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if it exists.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
Returns:
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_bert_text_input(tokenizer, ids_name, mask_name, segment_name)
writer.add_classification_output(labels)
return cls(writer)
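A usage sketch mirroring the tests added below; the model, vocabulary, and label file paths are placeholders for real assets:

from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier

with open("bert_text_classifier_no_metadata.tflite", "rb") as f:  # placeholder path
  model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
    model_buffer,
    tokenizer=metadata_writer.BertTokenizer(vocab_file_path="mobilebert_vocab.txt"),
    labels=metadata_writer.Labels().add_from_file("movie_review_labels.txt"))
tflite_with_metadata, metadata_json = writer.populate()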


@ -367,6 +367,42 @@ class ScoreThresholdingMdTest(absltest.TestCase):
self.assertEqual(metadata_json, expected_json)
class BertTokenizerMdTest(absltest.TestCase):
_VOCAB_FILE = "vocab.txt"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "bert_tokenizer_meta.json"))
def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class SentencePieceTokenizerMdTest(absltest.TestCase):
_VOCAB_FILE = "vocab.txt"
_SP_MODEL = "sp.model"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "sentence_piece_tokenizer_meta.json"))
def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
self._SP_MODEL, self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata.


@ -21,28 +21,64 @@ from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils
_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_REGEX_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
"movie_review_labels.txt")
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_REGEX_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_REGEX_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_BERT_MODEL = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_no_metadata.tflite")
_BERT_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR +
"mobilebert_vocab.txt")
_SP_MODEL_FILE = test_utils.get_test_data_path(_TEST_DIR + "30k-clean.model")
_BERT_JSON_FILE = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_with_bert_tokenizer.json")
_SENTENCE_PIECE_JSON_FILE = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_with_sentence_piece.json")
class TextClassifierTest(absltest.TestCase):
def test_write_metadata(self,):
with open(_MODEL, "rb") as f:
def test_write_metadata_for_regex_model(self):
with open(_REGEX_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_regex_model(
model_buffer,
regex_tokenizer=metadata_writer.RegexTokenizer(
delim_regex_pattern=_DELIM_REGEX_PATTERN,
vocab_file_path=_VOCAB_FILE),
vocab_file_path=_REGEX_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_JSON_FILE, "r") as f:
with open(_REGEX_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_for_bert(self):
with open(_BERT_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
model_buffer,
tokenizer=metadata_writer.BertTokenizer(_BERT_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_BERT_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_for_sentence_piece(self):
with open(_BERT_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
model_buffer,
tokenizer=metadata_writer.SentencePieceTokenizer(_SP_MODEL_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_SENTENCE_PIECE_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)


@ -23,10 +23,13 @@ package(
)
mediapipe_files(srcs = [
"30k-clean.model",
"bert_text_classifier_no_metadata.tflite",
"mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite",
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
"mobilebert_vocab.txt",
"mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
@ -60,11 +63,17 @@ exports_files([
"movie_review_labels.txt",
"regex_vocab.txt",
"movie_review.json",
"bert_tokenizer_meta.json",
"bert_text_classifier_with_sentence_piece.json",
"sentence_piece_tokenizer_meta.json",
"bert_text_classifier_with_bert_tokenizer.json",
])
filegroup(
name = "model_files",
srcs = [
"30k-clean.model",
"bert_text_classifier_no_metadata.tflite",
"mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite",
@ -81,6 +90,9 @@ filegroup(
name = "data_files",
srcs = [
"associated_file_meta.json",
"bert_text_classifier_with_bert_tokenizer.json",
"bert_text_classifier_with_sentence_piece.json",
"bert_tokenizer_meta.json",
"bounding_box_tensor_meta.json",
"classification_tensor_float_meta.json",
"classification_tensor_uint8_meta.json",
@ -96,6 +108,7 @@ filegroup(
"input_text_tensor_default_meta.json",
"input_text_tensor_meta.json",
"labels.txt",
"mobilebert_vocab.txt",
"mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json",
"movie_review.json",
@ -105,5 +118,6 @@ filegroup(
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
"score_thresholding_meta.json",
"sentence_piece_tokenizer_meta.json",
],
)


@ -0,0 +1,84 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "mobilebert_vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}


@ -0,0 +1,83 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "SentencePieceTokenizerOptions",
"options": {
"sentencePiece_model": [
{
"name": "30k-clean.model",
"description": "The sentence piece model file."
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}


@ -0,0 +1,20 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
]
}

File diff suppressed because it is too large.


@ -0,0 +1,26 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "SentencePieceTokenizerOptions",
"options": {
"sentencePiece_model": [
{
"name": "sp.model",
"description": "The sentence piece model file."
}
],
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.",
"type": "VOCABULARY"
}
]
}
}
]
}
]
}


@ -28,12 +28,36 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_no_metadata_tflite",
sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_no_metadata.tflite?generation=1667948360250899"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_tflite",
sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_with_bert_tokenizer_json",
sha256 = "49f148a13a4e3b486b1d3c2400e46e5ebd0d375674c0154278b835760e873a95",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_bert_tokenizer.json?generation=1667948363241334"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_with_sentence_piece_json",
sha256 = "113091f3892691de57e379387256b2ce0cc18a1b5185af866220a46da8221f26",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_sentence_piece.json?generation=1667948366009530"],
)
http_file(
name = "com_google_mediapipe_bert_tokenizer_meta_json",
sha256 = "116d70c7c3ef413a8bff54ab758f9ed3d6e51fdc5621d8c920ad2f0035831804",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_tokenizer_meta.json?generation=1667948368809108"],
)
http_file(
name = "com_google_mediapipe_bounding_box_tensor_meta_json",
sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",
@ -790,6 +814,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
)
http_file(
name = "com_google_mediapipe_sentence_piece_tokenizer_meta_json",
sha256 = "416bfe231710502e4a93e1b1950c0c6e5db49cffb256d241ef3d3f2d0d57718b",
urls = ["https://storage.googleapis.com/mediapipe-assets/sentence_piece_tokenizer_meta.json?generation=1667948375508564"],
)
http_file(
name = "com_google_mediapipe_speech_16000_hz_mono_wav",
sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6",