Metadata Writer: Add Bert metadata writer in Text Classifier.

PiperOrigin-RevId: 487354439

This commit is contained in:
parent 116b4bb6c4
commit d2142e86a9
@@ -12,9 +12,9 @@ py_library(
    srcs = [
        "metadata_info.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":writer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_py",
        "//mediapipe/tasks/metadata:schema_py",
    ],
@@ -14,12 +14,14 @@
# ==============================================================================
"""Helper classes for common model metadata information."""

import collections
import csv
import os
from typing import List, Optional, Type
from typing import List, Optional, Type, Union

from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils

# Min and max values for UINT8 tensors.
_MIN_UINT8 = 0
@@ -267,6 +269,86 @@ class RegexTokenizerMd:
    return tokenizer


class BertTokenizerMd:
  """A container for the Bert tokenizer [1] metadata information.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
  """

  def __init__(self, vocab_file_path: str):
    """Initializes a BertTokenizerMd object.

    Args:
      vocab_file_path: path to the vocabulary file.
    """
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the Bert tokenizer metadata based on the information.

    Returns:
      A Flatbuffers Python object of the Bert tokenizer metadata.
    """
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = self._vocab_file_path
    vocab.description = _VOCAB_FILE_DESCRIPTION
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
    tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
    tokenizer.options.vocabFile = [vocab]
    return tokenizer


class SentencePieceTokenizerMd:
  """A container for the sentence piece tokenizer [1] metadata information.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
  """

  _SP_MODEL_DESCRIPTION = "The sentence piece model file."
  _SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
      " This file is optional during tokenization, while the sentence piece "
      "model is mandatory.")

  def __init__(self,
               sentence_piece_model_path: str,
               vocab_file_path: Optional[str] = None):
    """Initializes a SentencePieceTokenizerMd object.

    Args:
      sentence_piece_model_path: path to the sentence piece model file.
      vocab_file_path: path to the vocabulary file.
    """
    self._sentence_piece_model_path = sentence_piece_model_path
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the sentence piece tokenizer metadata based on the information.

    Returns:
      A Flatbuffers Python object of the sentence piece tokenizer metadata.
    """
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = (
        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
    tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()

    sp_model = _metadata_fb.AssociatedFileT()
    sp_model.name = self._sentence_piece_model_path
    sp_model.description = self._SP_MODEL_DESCRIPTION
    tokenizer.options.sentencePieceModel = [sp_model]
    if self._vocab_file_path:
      vocab = _metadata_fb.AssociatedFileT()
      vocab.name = self._vocab_file_path
      vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
      vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
      tokenizer.options.vocabFile = [vocab]
    return tokenizer


class TensorMd:
  """A container for common tensor metadata information.
@@ -486,6 +568,145 @@ class InputTextTensorMd(TensorMd):
    return tensor_metadata


def _get_file_paths(files: List[_metadata_fb.AssociatedFileT]) -> List[str]:
  """Gets file paths from a list of associated files."""
  if not files:
    return []
  return [file.name for file in files]


def _get_tokenizer_associated_files(
    tokenizer_options: Optional[
        Union[_metadata_fb.BertTokenizerOptionsT,
              _metadata_fb.SentencePieceTokenizerOptionsT]]
) -> List[str]:
  """Gets a list of associated files packed in the tokenizer_options.

  Args:
    tokenizer_options: a tokenizer metadata object. Supports the following
      tokenizer types:
      1. BertTokenizerOptions:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
      2. SentencePieceTokenizerOptions:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485

  Returns:
    A list of associated files included in tokenizer_options.
  """

  if not tokenizer_options:
    return []

  if isinstance(tokenizer_options, _metadata_fb.BertTokenizerOptionsT):
    return _get_file_paths(tokenizer_options.vocabFile)
  elif isinstance(tokenizer_options,
                  _metadata_fb.SentencePieceTokenizerOptionsT):
    return _get_file_paths(tokenizer_options.vocabFile) + _get_file_paths(
        tokenizer_options.sentencePieceModel)
  else:
    return []


class BertInputTensorsMd:
  """A container for the input tensor metadata information of Bert models."""

  _IDS_NAME = "ids"
  _IDS_DESCRIPTION = "Tokenized ids of the input text."
  _MASK_NAME = "mask"
  _MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
                       "tokens.")
  _SEGMENT_IDS_NAME = "segment_ids"
  _SEGMENT_IDS_DESCRIPTION = (
      "0 for the first sequence, 1 for the second sequence if exists.")

  def __init__(self,
               model_buffer: bytearray,
               ids_name: str,
               mask_name: str,
               segment_name: str,
               tokenizer_md: Union[None, BertTokenizerMd,
                                   SentencePieceTokenizerMd] = None):
    """Initializes a BertInputTensorsMd object.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata.

    Args:
      model_buffer: valid buffer of the model file.
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with `1`
        for real tokens and `0` for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if it exists.
      tokenizer_md: information of the tokenizer used to process the input
        string, if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
        refer to `InputTensorsMd`.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
      [3]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
    """
    # Verify that tflite_input_names (read from the model) and
    # input_name (collected from users) are aligned.
    tflite_input_names = writer_utils.get_input_tensor_names(model_buffer)
    input_names = [ids_name, mask_name, segment_name]
    if collections.Counter(tflite_input_names) != collections.Counter(
        input_names):
      raise ValueError(
          f"The input tensor names ({input_names}) do not match the tensor "
          f"names read from the model ({tflite_input_names}).")

    ids_md = TensorMd(
        name=self._IDS_NAME,
        description=self._IDS_DESCRIPTION,
        tensor_name=ids_name)

    mask_md = TensorMd(
        name=self._MASK_NAME,
        description=self._MASK_DESCRIPTION,
        tensor_name=mask_name)

    segment_ids_md = TensorMd(
        name=self._SEGMENT_IDS_NAME,
        description=self._SEGMENT_IDS_DESCRIPTION,
        tensor_name=segment_name)

    self._input_md = [ids_md, mask_md, segment_ids_md]

    if not isinstance(tokenizer_md,
                      (type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
      raise ValueError(
          f"The type of tokenizer_md, {type(tokenizer_md)}, is unsupported")

    self._tokenizer_md = tokenizer_md

  def create_input_process_unit_metadata(
      self) -> List[_metadata_fb.ProcessUnitT]:
    """Creates the input process unit metadata."""
    if self._tokenizer_md:
      return [self._tokenizer_md.create_metadata()]
    else:
      return []

  def get_tokenizer_associated_files(self) -> List[str]:
    """Gets the associated files that are packed in the tokenizer."""
    if self._tokenizer_md:
      return _get_tokenizer_associated_files(
          self._tokenizer_md.create_metadata().options)
    else:
      return []

  @property
  def input_md(self) -> List[TensorMd]:
    return self._input_md


class ClassificationTensorMd(TensorMd):
  """A container for the classification tensor metadata information.
@@ -19,7 +19,7 @@ import csv
import dataclasses
import os
import tempfile
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
@@ -101,6 +101,34 @@ class RegexTokenizer:
  vocab_file_path: str


@dataclasses.dataclass
class BertTokenizer:
  """Parameters of the Bert tokenizer [1] metadata information.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477

  Attributes:
    vocab_file_path: path to the vocabulary file.
  """
  vocab_file_path: str


@dataclasses.dataclass
class SentencePieceTokenizer:
  """Parameters of the sentence piece tokenizer [1] metadata information.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485

  Attributes:
    sentence_piece_model_path: path to the sentence piece model file.
    vocab_file_path: path to the vocabulary file.
  """
  sentence_piece_model_path: str
  vocab_file_path: Optional[str] = None


class Labels(object):
  """Simple container holding classification labels of a particular tensor.
@@ -282,7 +310,9 @@ def _create_metadata_buffer(
    model_buffer: bytearray,
    general_md: Optional[metadata_info.GeneralMd] = None,
    input_md: Optional[List[metadata_info.TensorMd]] = None,
    output_md: Optional[List[metadata_info.TensorMd]] = None) -> bytearray:
    output_md: Optional[List[metadata_info.TensorMd]] = None,
    input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None
) -> bytearray:
  """Creates a buffer of the metadata.

  Args:
@@ -290,7 +320,9 @@ def _create_metadata_buffer(
    general_md: general information about the model.
    input_md: metadata information of the input tensors.
    output_md: metadata information of the output tensors.
    input_process_units: a list of metadata of the input process units [1].
    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655

  Returns:
    A buffer of the metadata.
@@ -325,6 +357,8 @@ def _create_metadata_buffer(
  subgraph_metadata = metadata_fb.SubGraphMetadataT()
  subgraph_metadata.inputTensorMetadata = input_metadata
  subgraph_metadata.outputTensorMetadata = output_metadata
  if input_process_units:
    subgraph_metadata.inputProcessUnits = input_process_units

  # Create the whole model metadata.
  if general_md is None:
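
A small sketch (not part of this commit) of what the new `input_process_units` plumbing amounts to: the tokenizer's `ProcessUnitT` ends up on the subgraph metadata. Object names come from this diff; the vocabulary path is a placeholder.

```python
# Sketch only, not part of the commit: how a tokenizer process unit is
# attached to the subgraph metadata when input_process_units is supplied.
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

tokenizer_unit = metadata_info.BertTokenizerMd("vocab.txt").create_metadata()

subgraph_metadata = metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputProcessUnits = [tokenizer_unit]
```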
@@ -366,6 +400,7 @@ class MetadataWriter(object):
    self._model_buffer = model_buffer
    self._general_md = None
    self._input_mds = []
    self._input_process_units = []
    self._output_mds = []
    self._associated_files = []
    self._temp_folder = tempfile.TemporaryDirectory()
@@ -416,7 +451,7 @@ class MetadataWriter(object):
      description: Description of the input tensor.

    Returns:
      The MetaWriter instance, can be used for chained operation.
      The MetadataWriter instance, can be used for chained operation.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
@@ -448,7 +483,7 @@ class MetadataWriter(object):
      description: Description of the input tensor.

    Returns:
      The MetaWriter instance, can be used for chained operation.
      The MetadataWriter instance, can be used for chained operation.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
@@ -462,6 +497,63 @@ class MetadataWriter(object):
    self._associated_files.append(regex_tokenizer.vocab_file_path)
    return self

  def add_bert_text_input(self, tokenizer: Union[BertTokenizer,
                                                 SentencePieceTokenizer],
                          ids_name: str, mask_name: str,
                          segment_name: str) -> 'MetadataWriter':
    """Adds metadata for the text input with a Bert / SentencePiece tokenizer.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata.

    Args:
      tokenizer: information of the tokenizer used to process the input string,
        if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2].
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with `1`
        for real tokens and `0` for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if it exists.

    Returns:
      The MetadataWriter instance, can be used for chained operation.

    Raises:
      ValueError: if the type of tokenizer is not BertTokenizer or
        SentencePieceTokenizer.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
    [2]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
    """
    if isinstance(tokenizer, BertTokenizer):
      tokenizer_md = metadata_info.BertTokenizerMd(
          vocab_file_path=tokenizer.vocab_file_path)
    elif isinstance(tokenizer, SentencePieceTokenizer):
      tokenizer_md = metadata_info.SentencePieceTokenizerMd(
          sentence_piece_model_path=tokenizer.sentence_piece_model_path,
          vocab_file_path=tokenizer.vocab_file_path)
    else:
      raise ValueError(
          f'The type of tokenizer, {type(tokenizer)}, is unsupported')
    bert_input_md = metadata_info.BertInputTensorsMd(
        self._model_buffer,
        ids_name,
        mask_name,
        segment_name,
        tokenizer_md=tokenizer_md)

    self._input_mds.extend(bert_input_md.input_md)
    self._associated_files.extend(
        bert_input_md.get_tokenizer_associated_files())
    self._input_process_units.extend(
        bert_input_md.create_input_process_unit_metadata())
    return self

  def add_classification_output(
      self,
      labels: Optional[Labels] = None,
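
A hedged sketch of the chained writer API with the new `add_bert_text_input` step; the model path, vocabulary path, label file, and tensor names are placeholders and must match the actual model.

```python
# Sketch only, not part of the commit. All file paths and tensor names are
# placeholders; the tensor names must match the model's input tensors.
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

with open("bert_classifier.tflite", "rb") as f:
  model_buffer = bytearray(f.read())

writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info("TextClassifier", "Classify the input text.")
writer.add_bert_text_input(
    tokenizer=metadata_writer.BertTokenizer(vocab_file_path="vocab.txt"),
    ids_name="serving_default_input_word_ids:0",
    mask_name="serving_default_input_mask:0",
    segment_name="serving_default_input_type_ids:0")
writer.add_classification_output(
    labels=metadata_writer.Labels().add_from_file("labels.txt"))

# populate() returns the model buffer with metadata plus the metadata JSON.
tflite_with_metadata, metadata_json = writer.populate()
```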
@@ -546,7 +638,8 @@ class MetadataWriter(object):
        model_buffer=self._model_buffer,
        general_md=self._general_md,
        input_md=self._input_mds,
        output_md=self._output_mds)
        output_md=self._output_mds,
        input_process_units=self._input_process_units)
    populator.load_metadata_buffer(metadata_buffer)
    if self._associated_files:
      populator.load_associated_files(self._associated_files)
@@ -14,11 +14,18 @@
# ==============================================================================
"""Writes metadata and label file to the Text classifier models."""

from typing import Union

from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

_MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")

# The input tensor names of models created by Model Maker.
_DEFAULT_ID_NAME = "serving_default_input_word_ids:0"
_DEFAULT_MASK_NAME = "serving_default_input_mask:0"
_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0"


class MetadataWriter(metadata_writer.MetadataWriterBase):
  """MetadataWriter to write the metadata into the text classifier."""
@@ -62,3 +69,51 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
    writer.add_regex_text_input(regex_tokenizer)
    writer.add_classification_output(labels)
    return cls(writer)

  @classmethod
  def create_for_bert_model(
      cls,
      model_buffer: bytearray,
      tokenizer: Union[metadata_writer.BertTokenizer,
                       metadata_writer.SentencePieceTokenizer],
      labels: metadata_writer.Labels,
      ids_name: str = _DEFAULT_ID_NAME,
      mask_name: str = _DEFAULT_MASK_NAME,
      segment_name: str = _DEFAULT_SEGMENT_ID_NAME,
  ) -> "MetadataWriter":
    """Creates MetadataWriter for models with {Bert/SentencePiece}Tokenizer.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata. The default values come from Model Maker.

    Args:
      model_buffer: valid buffer of the model file.
      tokenizer: information of the tokenizer used to process the input string,
        if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
        refer to `create_for_regex_model`.
      labels: an instance of Labels helper class used in the output
        classification tensor [4].
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with `1`
        for real tokens and `0` for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if it exists.
      [1]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
      [2]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
      [3]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
      [4]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99

    Returns:
      A MetadataWriter object.
    """
    writer = metadata_writer.MetadataWriter(model_buffer)
    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
    writer.add_bert_text_input(tokenizer, ids_name, mask_name, segment_name)
    writer.add_classification_output(labels)
    return cls(writer)
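
A hedged usage sketch mirroring the new tests below: writing metadata into a Bert text classifier via the convenience factory; the asset paths are placeholders.

```python
# Sketch only, not part of the commit; mirrors the new tests. File paths are
# placeholders for real model, vocabulary, and label assets.
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier

with open("bert_text_classifier_no_metadata.tflite", "rb") as f:
  model_buffer = f.read()

writer = text_classifier.MetadataWriter.create_for_bert_model(
    model_buffer,
    tokenizer=metadata_writer.BertTokenizer("mobilebert_vocab.txt"),
    labels=metadata_writer.Labels().add_from_file("labels.txt"))
tflite_with_metadata, metadata_json = writer.populate()
print(metadata_json)
```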
@@ -367,6 +367,42 @@ class ScoreThresholdingMdTest(absltest.TestCase):
    self.assertEqual(metadata_json, expected_json)


class BertTokenizerMdTest(absltest.TestCase):

  _VOCAB_FILE = "vocab.txt"
  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "bert_tokenizer_meta.json"))

  def test_create_metadata_should_succeed(self):
    tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
    tokenizer_metadata = tokenizer_md.create_metadata()

    metadata_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)


class SentencePieceTokenizerMdTest(absltest.TestCase):

  _VOCAB_FILE = "vocab.txt"
  _SP_MODEL = "sp.model"
  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "sentence_piece_tokenizer_meta.json"))

  def test_create_metadata_should_succeed(self):
    tokenizer_md = metadata_info.SentencePieceTokenizerMd(
        self._SP_MODEL, self._VOCAB_FILE)
    tokenizer_metadata = tokenizer_md.create_metadata()

    metadata_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)


def _create_dummy_model_metadata_with_tensor(
    tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
  # Create a dummy model using the tensor metadata.
@@ -21,28 +21,64 @@ from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils

_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_REGEX_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
                                            "movie_review_labels.txt")
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_REGEX_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_REGEX_JSON_FILE = test_utils.get_test_data_path("movie_review.json")

_BERT_MODEL = test_utils.get_test_data_path(
    _TEST_DIR + "bert_text_classifier_no_metadata.tflite")
_BERT_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR +
                                                 "mobilebert_vocab.txt")
_SP_MODEL_FILE = test_utils.get_test_data_path(_TEST_DIR + "30k-clean.model")
_BERT_JSON_FILE = test_utils.get_test_data_path(
    _TEST_DIR + "bert_text_classifier_with_bert_tokenizer.json")
_SENTENCE_PIECE_JSON_FILE = test_utils.get_test_data_path(
    _TEST_DIR + "bert_text_classifier_with_sentence_piece.json")


class TextClassifierTest(absltest.TestCase):

  def test_write_metadata(self,):
    with open(_MODEL, "rb") as f:
  def test_write_metadata_for_regex_model(self):
    with open(_REGEX_MODEL, "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_regex_model(
        model_buffer,
        regex_tokenizer=metadata_writer.RegexTokenizer(
            delim_regex_pattern=_DELIM_REGEX_PATTERN,
            vocab_file_path=_VOCAB_FILE),
            vocab_file_path=_REGEX_VOCAB_FILE),
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
    _, metadata_json = writer.populate()

    with open(_JSON_FILE, "r") as f:
    with open(_REGEX_JSON_FILE, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)

  def test_write_metadata_for_bert(self):
    with open(_BERT_MODEL, "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_bert_model(
        model_buffer,
        tokenizer=metadata_writer.BertTokenizer(_BERT_VOCAB_FILE),
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
    _, metadata_json = writer.populate()

    with open(_BERT_JSON_FILE, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)

  def test_write_metadata_for_sentence_piece(self):
    with open(_BERT_MODEL, "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_bert_model(
        model_buffer,
        tokenizer=metadata_writer.SentencePieceTokenizer(_SP_MODEL_FILE),
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
    _, metadata_json = writer.populate()

    with open(_SENTENCE_PIECE_JSON_FILE, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)
mediapipe/tasks/testdata/metadata/BUILD
@@ -23,10 +23,13 @@ package(
)

mediapipe_files(srcs = [
    "30k-clean.model",
    "bert_text_classifier_no_metadata.tflite",
    "mobile_ica_8bit-with-metadata.tflite",
    "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
    "mobile_ica_8bit-without-model-metadata.tflite",
    "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
    "mobilebert_vocab.txt",
    "mobilenet_v1_0.25_224_1_default_1.tflite",
    "mobilenet_v2_1.0_224_quant.tflite",
    "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
@@ -60,11 +63,17 @@ exports_files([
    "movie_review_labels.txt",
    "regex_vocab.txt",
    "movie_review.json",
    "bert_tokenizer_meta.json",
    "bert_text_classifier_with_sentence_piece.json",
    "sentence_piece_tokenizer_meta.json",
    "bert_text_classifier_with_bert_tokenizer.json",
])

filegroup(
    name = "model_files",
    srcs = [
        "30k-clean.model",
        "bert_text_classifier_no_metadata.tflite",
        "mobile_ica_8bit-with-metadata.tflite",
        "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
        "mobile_ica_8bit-without-model-metadata.tflite",
@@ -81,6 +90,9 @@ filegroup(
    name = "data_files",
    srcs = [
        "associated_file_meta.json",
        "bert_text_classifier_with_bert_tokenizer.json",
        "bert_text_classifier_with_sentence_piece.json",
        "bert_tokenizer_meta.json",
        "bounding_box_tensor_meta.json",
        "classification_tensor_float_meta.json",
        "classification_tensor_uint8_meta.json",
@@ -96,6 +108,7 @@ filegroup(
        "input_text_tensor_default_meta.json",
        "input_text_tensor_meta.json",
        "labels.txt",
        "mobilebert_vocab.txt",
        "mobilenet_v2_1.0_224.json",
        "mobilenet_v2_1.0_224_quant.json",
        "movie_review.json",
@@ -105,5 +118,6 @@ filegroup(
        "score_calibration_file_meta.json",
        "score_calibration_tensor_meta.json",
        "score_thresholding_meta.json",
        "sentence_piece_tokenizer_meta.json",
    ],
)
mediapipe/tasks/testdata/metadata/bert_text_classifier_with_bert_tokenizer.json (new file)
@@ -0,0 +1,84 @@
{
  "name": "TextClassifier",
  "description": "Classify the input text into a set of known categories.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "ids",
          "description": "Tokenized ids of the input text.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "segment_ids",
          "description": "0 for the first sequence, 1 for the second sequence if exists.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "mask",
          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "score",
          "description": "Score of the labels respectively.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
            "max": [
              1.0
            ],
            "min": [
              0.0
            ]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ],
      "input_process_units": [
        {
          "options_type": "BertTokenizerOptions",
          "options": {
            "vocab_file": [
              {
                "name": "mobilebert_vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ],
  "min_parser_version": "1.1.0"
}
mediapipe/tasks/testdata/metadata/bert_text_classifier_with_sentence_piece.json (new file)
@@ -0,0 +1,83 @@
{
  "name": "TextClassifier",
  "description": "Classify the input text into a set of known categories.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "ids",
          "description": "Tokenized ids of the input text.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "segment_ids",
          "description": "0 for the first sequence, 1 for the second sequence if exists.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "mask",
          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "score",
          "description": "Score of the labels respectively.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
            "max": [
              1.0
            ],
            "min": [
              0.0
            ]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ],
      "input_process_units": [
        {
          "options_type": "SentencePieceTokenizerOptions",
          "options": {
            "sentencePiece_model": [
              {
                "name": "30k-clean.model",
                "description": "The sentence piece model file."
              }
            ]
          }
        }
      ]
    }
  ],
  "min_parser_version": "1.1.0"
}
mediapipe/tasks/testdata/metadata/bert_tokenizer_meta.json (new file)
@@ -0,0 +1,20 @@
{
  "subgraph_metadata": [
    {
      "input_process_units": [
        {
          "options_type": "BertTokenizerOptions",
          "options": {
            "vocab_file": [
              {
                "name": "vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ]
}
mediapipe/tasks/testdata/metadata/mobilebert_vocab.txt (new file, 30522 lines; diff suppressed because it is too large)
mediapipe/tasks/testdata/metadata/sentence_piece_tokenizer_meta.json (new file)
@@ -0,0 +1,26 @@
{
  "subgraph_metadata": [
    {
      "input_process_units": [
        {
          "options_type": "SentencePieceTokenizerOptions",
          "options": {
            "sentencePiece_model": [
              {
                "name": "sp.model",
                "description": "The sentence piece model file."
              }
            ],
            "vocab_file": [
              {
                "name": "vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ]
}
third_party/external_files.bzl
@@ -28,12 +28,36 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_no_metadata_tflite",
        sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_no_metadata.tflite?generation=1667948360250899"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_tflite",
        sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_with_bert_tokenizer_json",
        sha256 = "49f148a13a4e3b486b1d3c2400e46e5ebd0d375674c0154278b835760e873a95",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_bert_tokenizer.json?generation=1667948363241334"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_with_sentence_piece_json",
        sha256 = "113091f3892691de57e379387256b2ce0cc18a1b5185af866220a46da8221f26",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_sentence_piece.json?generation=1667948366009530"],
    )

    http_file(
        name = "com_google_mediapipe_bert_tokenizer_meta_json",
        sha256 = "116d70c7c3ef413a8bff54ab758f9ed3d6e51fdc5621d8c920ad2f0035831804",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_tokenizer_meta.json?generation=1667948368809108"],
    )

    http_file(
        name = "com_google_mediapipe_bounding_box_tensor_meta_json",
        sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",
@@ -790,6 +814,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
    )

    http_file(
        name = "com_google_mediapipe_sentence_piece_tokenizer_meta_json",
        sha256 = "416bfe231710502e4a93e1b1950c0c6e5db49cffb256d241ef3d3f2d0d57718b",
        urls = ["https://storage.googleapis.com/mediapipe-assets/sentence_piece_tokenizer_meta.json?generation=1667948375508564"],
    )

    http_file(
        name = "com_google_mediapipe_speech_16000_hz_mono_wav",
        sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6",