Migrate base metadata functionality like MetadataPopulator and MetadataDisplayer class into MediaPipe.

PiperOrigin-RevId: 478279747
This commit is contained in:
Yuqi Li 2022-10-01 21:50:22 -07:00 committed by Copybara-Service
parent 9568de0570
commit 13f6e0c797
18 changed files with 2094 additions and 9 deletions

View File

@ -1,4 +1,4 @@
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
package( package(
default_visibility = [ default_visibility = [
@ -14,3 +14,13 @@ flatbuffer_cc_library(
name = "metadata_schema_cc", name = "metadata_schema_cc",
srcs = ["metadata_schema.fbs"], srcs = ["metadata_schema.fbs"],
) )
# Python bindings for the TFLite model flatbuffer schema, used by the
# metadata tools to parse .tflite files.
flatbuffer_py_library(
    name = "schema_py",
    srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
)

# Python bindings for the TFLite metadata flatbuffer schema.
flatbuffer_py_library(
    name = "metadata_schema_py",
    srcs = ["metadata_schema.fbs"],
)

View File

@ -0,0 +1,40 @@
load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version")
package(
licenses = ["notice"], # Apache 2.0
)
# Generates metadata_parser.py from its template with the current metadata
# parser version stamped in.
stamp_metadata_parser_version(
    name = "metadata_parser_py",
    srcs = ["metadata_parser.py.template"],
    outs = ["metadata_parser.py"],
)

# Python library for reading and writing TFLite model metadata
# (MetadataPopulator / MetadataDisplayer).
py_library(
    name = "metadata",
    srcs = [
        "metadata.py",
        ":metadata_parser_py",
    ],
    # The schema source file is needed at runtime to convert metadata
    # buffers to JSON (see convert_to_json in metadata.py).
    data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version",
        "//mediapipe/tasks/metadata:metadata_schema_py",
        "//mediapipe/tasks/metadata:schema_py",
        "//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers",
        "@flatbuffers//:runtime_py",
    ],
)

# Command-line tool that prints a model's metadata as JSON.
py_binary(
    name = "metadata_displayer_cli",
    srcs = ["metadata_displayer_cli.py"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":metadata",
    ],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,20 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = ["//mediapipe/tasks:internal"],
licenses = ["notice"], # Apache 2.0
)
# pybind11 extension exposing the flatbuffers Parser / text-generation API to
# Python; used by metadata.py to convert metadata buffers to JSON.
pybind_extension(
    name = "_pywrap_flatbuffers",
    srcs = [
        "flatbuffers_lib.cc",
    ],
    features = ["-use_header_modules"],
    module_name = "_pywrap_flatbuffers",
    deps = [
        "@flatbuffers",
        "@local_config_python//:python_headers",
        "@pybind11",
    ],
)

View File

@ -0,0 +1,59 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/idl.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
namespace tflite {
namespace support {

// NOTE(review): this file was migrated from TFLite Support and keeps the
// tflite::support namespace — confirm whether it should move to a mediapipe
// namespace.

// Python module exposing a minimal slice of the flatbuffers C++ API:
// schema parsing (IDLOptions, Parser) and binary-to-JSON text generation.
PYBIND11_MODULE(_pywrap_flatbuffers, m) {
  // Schema-parsing options; strict_json controls whether generated JSON
  // quotes field names.
  pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
      .def(pybind11::init<>())
      .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
  // Parser.parse(source) returns false on failure; the diagnostic message is
  // then available via the read-only `error` attribute.
  pybind11::class_<flatbuffers::Parser>(m, "Parser")
      .def(pybind11::init<const flatbuffers::IDLOptions&>())
      .def("parse",
           [](flatbuffers::Parser* self, const std::string& source) {
             return self->Parse(source.c_str());
           })
      .def_readonly("builder", &flatbuffers::Parser::builder_)
      .def_readonly("error", &flatbuffers::Parser::error_);
  pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
      .def("clear", &flatbuffers::FlatBufferBuilder::Clear)
      // Copies an already-serialized flatbuffer (Python bytes) into the
      // builder.
      .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
                                  const std::string& contents) {
        self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
                             contents.length());
      });
  m.def("generate_text_file", &flatbuffers::GenerateTextFile);
  // Converts a binary flatbuffer to JSON text using the schema previously
  // loaded into `parser`; returns an empty string on failure.
  m.def(
      "generate_text",
      [](const flatbuffers::Parser& parser,
         const std::string& buffer) -> std::string {
        std::string text;
        if (!flatbuffers::GenerateText(
                parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
          return "";
        }
        return text;
      });
}

}  // namespace support
}  // namespace tflite

View File

@ -0,0 +1,865 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""
import copy
import inspect
import io
import os
import shutil
import sys
import tempfile
import warnings
import zipfile
import flatbuffers
from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers
try:
  # If exists, optionally use TensorFlow to open and check files. Used to
  # support more than local file systems.
  # In pip requirements, we don't necessarily need tensorflow as a dep.
  import tensorflow as tf
  _open_file = tf.io.gfile.GFile
  _exists_file = tf.io.gfile.exists
except ImportError as e:
  # If TensorFlow package doesn't exist, fall back to original open and exists.
  _open_file = open
  _exists_file = os.path.exists
def _maybe_open_as_binary(filename, mode):
"""Maybe open the binary file, and returns a file-like."""
if hasattr(filename, "read"): # A file-like has read().
return filename
openmode = mode if "b" in mode else mode + "b" # Add binary explicitly.
return _open_file(filename, openmode)
def _open_as_zipfile(filename, mode="r"):
  """Open file as a zipfile.

  Args:
    filename: str or file-like or path-like, to the zipfile.
    mode: str, common file mode for zip.
      (See: https://docs.python.org/3/library/zipfile.html)

  Returns:
    A ZipFile object.
  """
  # Normalize the input to a binary file-like before handing it to ZipFile.
  binary_file = _maybe_open_as_binary(filename, mode)
  return zipfile.ZipFile(binary_file, mode)
def _is_zipfile(filename):
  """Checks whether `filename` (path or file-like) is a valid zip archive."""
  # Open in binary mode; zipfile.is_zipfile accepts a file-like directly.
  with _maybe_open_as_binary(filename, "r") as handle:
    return zipfile.is_zipfile(handle)
def get_path_to_datafile(path):
  """Gets the path to the specified file in the data dependencies.

  The path is relative to the file calling the function. It is a simple
  replacement of
  "tensorflow.python.platform.resource_loader.get_path_to_datafile".

  Args:
    path: a string resource path relative to the calling file.

  Returns:
    The path to the specified file present in the data attribute of py_test
    or py_binary.
  """
  # sys._getframe(1) is the caller's frame; resolve `path` against the
  # directory that holds the caller's source file.
  caller_file = inspect.getfile(sys._getframe(1))  # pylint: disable=protected-access
  return os.path.join(os.path.dirname(caller_file), path)
# Path to the TFLite metadata schema source, resolved relative to this file.
# It is made available at runtime via the `data` attribute of the BUILD rule.
_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile(
    "../../metadata/metadata_schema.fbs")

# TODO: add delete method for associated files.
class MetadataPopulator(object):
  """Packs metadata and associated files into TensorFlow Lite model file.

  MetadataPopulator can be used to populate metadata and model associated files
  into a model file or a model buffer (in bytearray). It can also help to
  inspect list of files that have been packed into the model or are supposed to
  be packed into the model.

  The metadata file (or buffer) should be generated based on the metadata
  schema:
  third_party/tensorflow/lite/schema/metadata_schema.fbs

  Example usage:
  Populate metadata and label file into an image classifier model.

  First, based on metadata_schema.fbs, generate the metadata for this image
  classifier model using Flatbuffers API. Attach the label file onto the output
  tensor (the tensor of probabilities) in the metadata.

  Then, pack the metadata and label file into the model as follows.

    ```python
    # Populating a metadata file (or a metadata buffer) and associated files to
    a model file:
    populator = MetadataPopulator.with_model_file(model_file)
    # For metadata buffer (bytearray read from the metadata file), use:
    # populator.load_metadata_buffer(metadata_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files([label.txt])
    # For associated file buffer (bytearray read from the file), use:
    # populator.load_associated_file_buffers({"label.txt": b"file content"})
    populator.populate()

    # Populating a metadata file (or a metadata buffer) and associated files to
    a model buffer:
    populator = MetadataPopulator.with_model_buffer(model_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files([label.txt])
    populator.populate()
    # Writing the updated model buffer into a file.
    updated_model_buf = populator.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)

    # Transferring metadata and associated files from another TFLite model:
    populator_dst = MetadataPopulator.with_model_buffer(model_buf)
    populator_dst.load_metadata_and_associated_files(src_model_buf)
    populator_dst.populate()
    updated_model_buf = populator_dst.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)
    ```

  Note that existing metadata buffer (if applied) will be overridden by the new
  metadata buffer.
  """
  # As Zip API is used to concatenate associated files after tflite model file,
  # the populating operation is developed based on a model file. For in-memory
  # model buffer, we create a tempfile to serve the populating operation.
  # Creating and deleting such a tempfile is handled by the class,
  # _MetadataPopulatorWithBuffer.

  # Name of the metadata entry in schema.Model.metadata.
  METADATA_FIELD_NAME = "TFLITE_METADATA"
  # Flatbuffer file identifier of a TFLite model.
  TFLITE_FILE_IDENTIFIER = b"TFL3"
  # Flatbuffer file identifier of a TFLite metadata buffer.
  METADATA_FILE_IDENTIFIER = b"M001"
  def __init__(self, model_file):
    """Constructor for MetadataPopulator.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    _assert_model_file_identifier(model_file)
    self._model_file = model_file
    # Serialized metadata flatbuffer (with min parser version stamped in);
    # set by load_metadata_buffer / load_metadata_file.
    self._metadata_buf = None
    # _associated_files is a dict of file name and file buffer.
    self._associated_files = {}

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataPopulator object that populates data to a model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataPopulator object.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return cls(model_file)

  # TODO: investigate if type check can be applied to model_buf for
  # FB.
  @classmethod
  def with_model_buffer(cls, model_buf):
    """Creates a MetadataPopulator object that populates data to a model buffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Returns:
      A MetadataPopulator(_MetadataPopulatorWithBuffer) object.

    Raises:
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return _MetadataPopulatorWithBuffer(model_buf)

  def get_model_buffer(self):
    """Gets the buffer of the model with packed metadata and associated files.

    Returns:
      Model buffer (in bytearray).
    """
    with _open_file(self._model_file, "rb") as f:
      return f.read()
  def get_packed_associated_file_list(self):
    """Gets a list of associated files packed to the model file.

    Returns:
      List of packed associated files.
    """
    # Associated files are appended to the model as a zip archive, so a model
    # without packed files is not a valid zipfile.
    if not _is_zipfile(self._model_file):
      return []
    with _open_as_zipfile(self._model_file, "r") as zf:
      return zf.namelist()

  def get_recorded_associated_file_list(self):
    """Gets a list of associated files recorded in metadata of the model file.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Returns:
      List of recorded associated files.
    """
    if not self._metadata_buf:
      return []
    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
            self._metadata_buf, 0))
    return [
        file.name.decode("utf-8")
        for file in self._get_recorded_associated_file_object_list(metadata)
    ]

  def load_associated_file_buffers(self, associated_files):
    """Loads the associated file buffers (in bytearray) to be populated.

    Args:
      associated_files: a dictionary of associated file names and corresponding
        file buffers, such as {"file.txt": b"file content"}. If pass in file
        paths for the file name, only the basename will be populated.
    """
    # Only the basename is kept so that the packed file name matches the name
    # recorded in the metadata (see
    # _use_basename_for_associated_files_in_metadata).
    self._associated_files.update({
        os.path.basename(name): buffers
        for name, buffers in associated_files.items()
    })

  def load_associated_files(self, associated_files):
    """Loads associated files that to be concatenated after the model file.

    Args:
      associated_files: list of file paths.

    Raises:
      IOError:
        File not found.
    """
    for af_name in associated_files:
      _assert_file_exist(af_name)
      with _open_file(af_name, "rb") as af:
        self.load_associated_file_buffers({af_name: af.read()})
  def load_metadata_buffer(self, metadata_buf):
    """Loads the metadata buffer (in bytearray) to be populated.

    Args:
      metadata_buf: metadata buffer (in bytearray) to be populated.

    Raises:
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    if not metadata_buf:
      raise ValueError("The metadata to be populated is empty.")
    self._validate_metadata(metadata_buf)
    # Gets the minimum metadata parser version of the metadata_buf.
    min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
        bytes(metadata_buf))
    # Inserts in the minimum metadata parser version into the metadata_buf.
    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version
    # Remove local file directory in the `name` field of `AssociatedFileT`, and
    # make it consistent with the name of the actual file packed in the model.
    self._use_basename_for_associated_files_in_metadata(metadata)
    # Re-serialize, stamping the metadata file identifier so consumers can
    # verify the buffer type.
    b = flatbuffers.Builder(0)
    b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
    metadata_buf_with_version = b.Output()
    self._metadata_buf = metadata_buf_with_version

  def load_metadata_file(self, metadata_file):
    """Loads the metadata file to be populated.

    Args:
      metadata_file: path to the metadata file to be populated.

    Raises:
      IOError: File not found.
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    _assert_file_exist(metadata_file)
    with _open_file(metadata_file, "rb") as f:
      metadata_buf = f.read()
    self.load_metadata_buffer(bytearray(metadata_buf))

  def load_metadata_and_associated_files(self, src_model_buf):
    """Loads the metadata and associated files from another model buffer.

    Args:
      src_model_buf: source model buffer (in bytearray) with metadata and
        associated files.
    """
    # Load the model metadata from src_model_buf if exist.
    metadata_buffer = get_metadata_buffer(src_model_buf)
    if metadata_buffer:
      self.load_metadata_buffer(metadata_buffer)
    # Load the associated files from src_model_buf if exist.
    if _is_zipfile(io.BytesIO(src_model_buf)):
      with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
        self.load_associated_file_buffers(
            {f: zf.read(f) for f in zf.namelist()})

  def populate(self):
    """Populates loaded metadata and associated files into the model file."""
    self._assert_validate()
    self._populate_metadata_buffer()
    self._populate_associated_files()
  def _assert_validate(self):
    """Validates the metadata and associated files to be populated.

    Raises:
      ValueError:
        File is recorded in the metadata, but is not going to be populated.
        File has already been packed.
    """
    # Gets files that are recorded in metadata.
    recorded_files = self.get_recorded_associated_file_list()
    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()
    # Gets the file name of those associated files to be populated.
    to_be_populated_files = self._associated_files.keys()
    # Checks all files recorded in the metadata will be populated.
    for rf in recorded_files:
      if rf not in to_be_populated_files and rf not in packed_files:
        raise ValueError("File, '{0}', is recorded in the metadata, but has "
                         "not been loaded into the populator.".format(rf))
    for f in to_be_populated_files:
      if f in packed_files:
        raise ValueError("File, '{0}', has already been packed.".format(f))
      # Packing a file that is not recorded in metadata is allowed, but warn
      # since it is usually a mistake.
      if f not in recorded_files:
        warnings.warn(
            "File, '{0}', does not exist in the metadata. But packing it to "
            "tflite model is still allowed.".format(f))

  def _copy_archived_files(self, src_zip, file_list, dst_zip):
    """Copies archived files in file_list from src_zip to dst_zip."""
    if not _is_zipfile(src_zip):
      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
    with _open_as_zipfile(src_zip, "r") as src_zf, \
        _open_as_zipfile(dst_zip, "a") as dst_zf:
      src_list = src_zf.namelist()
      for f in file_list:
        if f not in src_list:
          raise ValueError(
              "File, '{0}', does not exist in the zipfile, {1}.".format(
                  f, src_zip))
        file_buffer = src_zf.read(f)
        dst_zf.writestr(f, file_buffer)

  def _get_associated_files_from_process_units(self, table, field_name):
    """Gets the files that are attached the process units field of a table.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        ProcessUnit, such as TensorMetadata and SubGraphMetadata.
      field_name: the name of the field in the table that represents an array of
        ProcessUnit. If the table is TensorMetadata, field_name can be
        "ProcessUnits". If the table is SubGraphMetadata, field_name can be
        either "InputProcessUnits" or "OutputProcessUnits".

    Returns:
      A list of AssociatedFileT objects.
    """
    if table is None:
      return []
    file_list = []
    process_units = getattr(table, field_name)
    # If the process_units field is not populated, it will be None. Use an
    # empty list to skip the check.
    for process_unit in process_units or []:
      options = process_unit.options
      # Only the tokenizer options carry associated vocab / model files.
      if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
                              _metadata_fb.RegexTokenizerOptionsT)):
        file_list += self._get_associated_files_from_table(options, "vocabFile")
      elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
        file_list += self._get_associated_files_from_table(
            options, "sentencePieceModel")
        file_list += self._get_associated_files_from_table(options, "vocabFile")
    return file_list
  def _get_associated_files_from_table(self, table, field_name):
    """Gets the associated files that are attached a table directly.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
      field_name: the name of the field in the table that represents an array of
        AssociatedFile. If the table is TensorMetadata, field_name can be
        "AssociatedFiles". If the table is BertTokenizerOptions, field_name can
        be "VocabFile".

    Returns:
      A list of AssociatedFileT objects.
    """
    if table is None:
      return []
    # If the associated file field is not populated,
    # `getattr(table, field_name)` will be None. Return an empty list.
    return getattr(table, field_name) or []

  def _get_recorded_associated_file_object_list(self, metadata):
    """Gets a list of AssociatedFileT objects recorded in the metadata.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Args:
      metadata: the ModelMetadataT object.

    Returns:
      List of recorded AssociatedFileT objects.
    """
    recorded_files = []
    # Add associated files attached to ModelMetadata.
    recorded_files += self._get_associated_files_from_table(
        metadata, "associatedFiles")
    # Add associated files attached to each SubgraphMetadata.
    for subgraph in metadata.subgraphMetadata or []:
      recorded_files += self._get_associated_files_from_table(
          subgraph, "associatedFiles")
      # Add associated files attached to each input tensor.
      for tensor_metadata in subgraph.inputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")
      # Add associated files attached to each output tensor.
      for tensor_metadata in subgraph.outputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")
      # Add associated files attached to the input_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "inputProcessUnits")
      # Add associated files attached to the output_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "outputProcessUnits")
    return recorded_files

  def _populate_associated_files(self):
    """Concatenates associated files after TensorFlow Lite model file.

    If the MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.
    """
    # Opens up the model file in "appending" mode.
    # If self._model_file already has pack files, zipfile will concatenate
    # addition files after self._model_file. For example, suppose we have
    # self._model_file = old_tflite_file | label1.txt | label2.txt
    # Then after trigger populate() to add label3.txt, self._model_file becomes
    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
    with tempfile.SpooledTemporaryFile() as temp:
      # (1) Copy content from the model file to a temp file.
      with _open_file(self._model_file, "rb") as f:
        shutil.copyfileobj(f, temp)
      # (2) Append the associated files to the temp file as a zip.
      with _open_as_zipfile(temp, "a") as zf:
        for file_name, file_buffer in self._associated_files.items():
          zf.writestr(file_name, file_buffer)
      # (3) Copy the temp file back to the model file.
      temp.seek(0)
      with _open_file(self._model_file, "wb") as f:
        shutil.copyfileobj(temp, f)
  def _populate_metadata_buffer(self):
    """Populates the metadata buffer (in bytearray) into the model file.

    Inserts metadata_buf into the metadata field of schema.Model. If the
    MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.

    Existing metadata buffer (if applied) will be overridden by the new metadata
    buffer.
    """
    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()
    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = self._metadata_buf
    is_populated = False
    if not model.metadata:
      model.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model.metadata:
        if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
          is_populated = True
          # Reuse the existing buffer slot instead of appending a new one.
          model.buffers[meta.buffer] = buffer_field
    if not is_populated:
      if not model.buffers:
        model.buffers = []
      model.buffers.append(buffer_field)
      # Creates a new metadata field.
      metadata_field = _schema_fb.MetadataT()
      metadata_field.name = self.METADATA_FIELD_NAME
      metadata_field.buffer = len(model.buffers) - 1
      model.metadata.append(metadata_field)
    # Packs model back to a flatbuffer binary file.
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
    model_buf = b.Output()
    # Saves the updated model buffer to model file.
    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()
    if packed_files:
      # Writes the updated model buffer and associated files into a new model
      # file (in memory). Then overwrites the original model file.
      with tempfile.SpooledTemporaryFile() as temp:
        temp.write(model_buf)
        self._copy_archived_files(self._model_file, packed_files, temp)
        temp.seek(0)
        with _open_file(self._model_file, "wb") as f:
          shutil.copyfileobj(temp, f)
    else:
      with _open_file(self._model_file, "wb") as f:
        f.write(model_buf)

  def _use_basename_for_associated_files_in_metadata(self, metadata):
    """Removes any associated file local directory (if exists)."""
    for file in self._get_recorded_associated_file_object_list(metadata):
      file.name = os.path.basename(file.name)

  def _validate_metadata(self, metadata_buf):
    """Validates the metadata to be populated."""
    _assert_metadata_buffer_identifier(metadata_buf)
    # Verify the number of SubgraphMetadata is exactly one.
    # TFLite currently only support one subgraph.
    model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
        metadata_buf, 0)
    if model_meta.SubgraphMetadataLength() != 1:
      raise ValueError("The number of SubgraphMetadata should be exactly one, "
                       "but got {0}.".format(
                           model_meta.SubgraphMetadataLength()))
    # Verify if the number of tensor metadata matches the number of tensors.
    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()
    model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
    num_input_tensors = model.Subgraphs(0).InputsLength()
    num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
    if num_input_tensors != num_input_meta:
      raise ValueError(
          "The number of input tensors ({0}) should match the number of "
          "input tensor metadata ({1})".format(num_input_tensors,
                                               num_input_meta))
    num_output_tensors = model.Subgraphs(0).OutputsLength()
    num_output_meta = model_meta.SubgraphMetadata(
        0).OutputTensorMetadataLength()
    if num_output_tensors != num_output_meta:
      raise ValueError(
          "The number of output tensors ({0}) should match the number of "
          "output tensor metadata ({1})".format(num_output_tensors,
                                                num_output_meta))
class _MetadataPopulatorWithBuffer(MetadataPopulator):
  """Subclass of MetadataPopulator that populates metadata to a model buffer.

  This class is used to populate metadata into an in-memory model buffer. As we
  use Zip API to concatenate associated files after tflite model file, the
  populating operation is developed based on a model file. For in-memory model
  buffer, we create a tempfile to serve the populating operation. This class is
  then used to generate this tempfile, and delete the file when the
  MetadataPopulator object is deleted.
  """

  def __init__(self, model_buf):
    """Constructor for _MetadataPopulatorWithBuffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Raises:
      ValueError: model_buf is empty.
      ValueError: model_buf does not have the expected flatbuffer identifier.
    """
    if not model_buf:
      raise ValueError("model_buf cannot be empty.")
    # Create the temp file atomically and keep its path. The previous
    # implementation read `temp.name` out of a closed NamedTemporaryFile and
    # reused the path, which is racy (another process may claim the name) and
    # fails on Windows while the file is still open. mkstemp avoids both.
    fd, model_file = tempfile.mkstemp()
    os.close(fd)
    with _open_file(model_file, "wb") as f:
      f.write(model_buf)
    super().__init__(model_file)

  def __del__(self):
    """Destructor of _MetadataPopulatorWithBuffer.

    Deletes the tempfile.
    """
    if os.path.exists(self._model_file):
      os.remove(self._model_file)
class MetadataDisplayer(object):
  """Displays metadata and associated file info in human-readable format."""

  def __init__(self, model_buffer, metadata_buffer, associated_file_list):
    """Constructor for MetadataDisplayer.

    Args:
      model_buffer: valid buffer of the model file.
      metadata_buffer: valid buffer of the metadata file.
      associated_file_list: list of associate files in the model file.
    """
    _assert_model_buffer_identifier(model_buffer)
    _assert_metadata_buffer_identifier(metadata_buffer)
    self._model_buffer = model_buffer
    self._metadata_buffer = metadata_buffer
    self._associated_file_list = associated_file_list

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataDisplayer object for the model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataDisplayer object.

    Raises:
      IOError: File not found.
      ValueError: The model does not have metadata.
    """
    _assert_file_exist(model_file)
    with _open_file(model_file, "rb") as f:
      return cls.with_model_buffer(f.read())

  @classmethod
  def with_model_buffer(cls, model_buffer):
    """Creates a MetadataDisplayer object for a file buffer.

    Args:
      model_buffer: TensorFlow Lite model buffer in bytearray.

    Returns:
      MetadataDisplayer object.

    Raises:
      ValueError: model_buffer is empty.
      ValueError: The model does not have metadata.
    """
    if not model_buffer:
      raise ValueError("model_buffer cannot be empty.")
    metadata_buffer = get_metadata_buffer(model_buffer)
    if not metadata_buffer:
      raise ValueError("The model does not have metadata.")
    associated_file_list = cls._parse_packed_associted_file_list(model_buffer)
    return cls(model_buffer, metadata_buffer, associated_file_list)

  def get_associated_file_buffer(self, filename):
    """Get the specified associated file content in bytearray.

    Args:
      filename: name of the file to be extracted.

    Returns:
      The file content in bytearray.

    Raises:
      ValueError: if the file does not exist in the model.
    """
    if filename not in self._associated_file_list:
      raise ValueError(
          "The file, {}, does not exist in the model.".format(filename))
    with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf:
      return zf.read(filename)

  def get_metadata_buffer(self):
    """Get the metadata buffer in bytearray out from the model."""
    # Return a copy so callers cannot mutate the displayer's internal state.
    return copy.deepcopy(self._metadata_buffer)

  def get_metadata_json(self):
    """Converts the metadata into a json string."""
    return convert_to_json(self._metadata_buffer)

  def get_packed_associated_file_list(self):
    """Returns a list of associated files that are packed in the model.

    Returns:
      A name list of associated files.
    """
    return copy.deepcopy(self._associated_file_list)

  @staticmethod
  def _parse_packed_associted_file_list(model_buf):
    """Gets a list of associated files packed to the model file.

    NOTE(review): the method name misspells "associated"; kept as-is to avoid
    breaking any existing callers.

    Args:
      model_buf: valid file buffer.

    Returns:
      List of packed associated files.
    """
    try:
      with _open_as_zipfile(io.BytesIO(model_buf)) as zf:
        return zf.namelist()
    except zipfile.BadZipFile:
      # A model without packed associated files is not a zip archive.
      return []
# Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util.
def convert_to_json(metadata_buffer):
  """Converts the metadata into a json string.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: error occurred when parsing the metadata schema file.
  """
  options = _pywrap_flatbuffers.IDLOptions()
  options.strict_json = True
  parser = _pywrap_flatbuffers.Parser(options)
  # Load the metadata schema so the parser knows how to interpret the buffer.
  with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as schema_file:
    schema_content = schema_file.read()
  if not parser.parse(schema_content):
    raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
  return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
def _assert_file_exist(filename):
  """Raises IOError if `filename` does not exist."""
  # _exists_file is gfile-backed when TensorFlow is available.
  if not _exists_file(filename):
    raise IOError("File, '{0}', does not exist.".format(filename))
def _assert_model_file_identifier(model_file):
  """Checks if a model file has the expected TFLite schema identifier."""
  _assert_file_exist(model_file)
  # Read the whole file and defer to the buffer-level check.
  with _open_file(model_file, "rb") as model_fh:
    _assert_model_buffer_identifier(model_fh.read())
def _assert_model_buffer_identifier(model_buf):
  """Checks if a model buffer has the expected TFLite schema identifier."""
  if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
    raise ValueError(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.")
def _assert_metadata_buffer_identifier(metadata_buf):
  """Checks if a metadata buffer has the expected Metadata schema identifier."""
  has_identifier = _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
      metadata_buf, 0)
  if not has_identifier:
    raise ValueError(
        "The metadata buffer does not have the expected identifier, and may not"
        " be a valid TFLite Metadata.")
def get_metadata_buffer(model_buf):
  """Returns the metadata in the model file as a buffer.

  Args:
    model_buf: valid buffer of the model file.

  Returns:
    Metadata buffer. Returns `None` if the model does not have metadata.
  """
  tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
  # Scan the model's metadata entries for the TFLite metadata field and
  # return the bytes of the buffer it points at.
  for index in range(tflite_model.MetadataLength()):
    meta = tflite_model.Metadata(index)
    if meta.Name().decode("utf-8") != MetadataPopulator.METADATA_FIELD_NAME:
      continue
    metadata = tflite_model.Buffers(meta.Buffer())
    return metadata.DataAsNumpy().tobytes()
  return None

View File

@ -0,0 +1,34 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLI tool for display metadata."""
from absl import app
from absl import flags
from mediapipe.tasks.python.metadata import metadata
FLAGS = flags.FLAGS

# Command-line flags for the metadata displayer CLI.
flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.')
flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.')
def main(_):
  """Reads the metadata from --model_path and writes it to --export_json_path."""
  # NOTE(review): neither flag is marked required; passing None will fail
  # inside with_model_file/open. Consider flags.mark_flag_as_required.
  displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path)
  with open(FLAGS.export_json_path, 'w') as f:
    # Dump the model metadata as a JSON string.
    f.write(displayer.get_metadata_json())


if __name__ == '__main__':
  app.run(main)

View File

@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information about the metadata parser that this python library depends on."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MetadataParser(object):
  """Information about the metadata parser."""

  # The version of the metadata parser. The placeholder below is substituted
  # with the real version string at build time (this file is a .template
  # consumed by the stamp_metadata_parser_version rule).
  VERSION = "{LATEST_METADATA_PARSER_VERSION}"

View File

@ -19,9 +19,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
py_library( py_library(
name = "test_util", name = "test_utils",
testonly = 1, testonly = 1,
srcs = ["test_util.py"], srcs = ["test_utils.py"],
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",

View File

@ -0,0 +1,31 @@
# Test targets for the MediaPipe Tasks Python metadata library.
package(
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)
# Tests MetadataPopulator/MetadataDisplayer against the checked-in test models.
py_test(
    name = "metadata_test",
    srcs = ["metadata_test.py"],
    data = ["//mediapipe/tasks/testdata/metadata:data_files"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        "//mediapipe/tasks/metadata:metadata_schema_py",
        "//mediapipe/tasks/metadata:schema_py",
        "//mediapipe/tasks/python/metadata",
        "//mediapipe/tasks/python/test:test_utils",
        "@flatbuffers//:runtime_py",
    ],
)
# Validates the stamped metadata parser version string.
py_test(
    name = "metadata_parser_test",
    srcs = ["metadata_parser_test.py"],
    python_version = "PY3",
    # PY3 (not PY2AND3) for consistency with metadata_test and the rest of
    # this Python-3-only package.
    srcs_version = "PY3",
    deps = [
        "//mediapipe/tasks/python/metadata",
    ],
)

View File

@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mediapipe.tasks.python.metadata.metadata_parser."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from absl.testing import absltest
from mediapipe.tasks.python.metadata import metadata_parser
class MetadataParserTest(absltest.TestCase):
  """Checks the shape of the stamped metadata parser version."""

  def testVersionWellFormedSemanticVersion(self):
    # Validates that the version is well-formed (x.y.z).
    semver_pattern = re.compile('[0-9]+\\.[0-9]+\\.[0-9]+')
    self.assertIsNotNone(
        semver_pattern.match(metadata_parser.MetadataParser.VERSION))
if __name__ == '__main__':
  # Allows running this test directly: python metadata_parser_test.py
  absltest.main()

View File

@ -0,0 +1,857 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mediapipe.tasks.python.metadata.metadata."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
import six
import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata import metadata as _metadata
from mediapipe.tasks.python.test import test_utils
class Tokenizer(enum.Enum):
  """Tokenizer process-unit kinds exercised by the parameterized tests."""
  BERT_TOKENIZER = 0
  SENTENCE_PIECE = 1
class TensorType(enum.Enum):
  """Whether a tensor under test is a model input or a model output."""
  INPUT = 0
  OUTPUT = 1
def _read_file(file_name, mode="rb"):
with open(file_name, mode) as f:
return f.read()
class MetadataTest(parameterized.TestCase):
  """Shared fixtures for the metadata populator/displayer tests.

  Builds a minimal two-input/one-output TFLite model buffer, a matching
  metadata file, and a handful of temp files used as associated files.
  """

  def setUp(self):
    super(MetadataTest, self).setUp()
    self._invalid_model_buf = None
    self._invalid_file = "not_existed_file"
    # A serialized model with two inputs and one output.
    self._model_buf = self._create_model_buf()
    self._model_file = self.create_tempfile().full_path
    with open(self._model_file, "wb") as f:
      f.write(self._model_buf)
    # Metadata matching the model, and a copy stamped with minParserVersion.
    self._metadata_file = self._create_metadata_file()
    self._metadata_file_with_version = self._create_metadata_file_with_version(
        self._metadata_file, "1.0.0")
    # Associated files used across tests; only file2 has content.
    self._file1 = self.create_tempfile("file1").full_path
    self._file2 = self.create_tempfile("file2").full_path
    self._file2_content = b"file2_content"
    with open(self._file2, "wb") as f:
      f.write(self._file2_content)
    self._file3 = self.create_tempfile("file3").full_path

  def _create_model_buf(self):
    # Create a model with two inputs and one output, which matches the metadata
    # created by _create_metadata_file().
    metadata_field = _schema_fb.MetadataT()
    subgraph = _schema_fb.SubGraphT()
    subgraph.inputs = [0, 1]
    subgraph.outputs = [2]
    metadata_field.name = "meta"
    buffer_field = _schema_fb.BufferT()
    model = _schema_fb.ModelT()
    model.subgraphs = [subgraph]
    # Creates the metadata and buffer fields for testing purposes.
    model.metadata = [metadata_field, metadata_field]
    model.buffers = [buffer_field, buffer_field, buffer_field]
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(
        model.Pack(model_builder),
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
    return model_builder.Output()

  def _create_metadata_file(self):
    # Builds metadata with one associated file on the model and one on the
    # output tensor, then serializes it to a temp file.
    associated_file1 = _metadata_fb.AssociatedFileT()
    associated_file1.name = b"file1"
    associated_file2 = _metadata_fb.AssociatedFileT()
    associated_file2.name = b"file2"
    self.expected_recorded_files = [
        six.ensure_str(associated_file1.name),
        six.ensure_str(associated_file2.name)
    ]
    input_meta = _metadata_fb.TensorMetadataT()
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.associatedFiles = [associated_file2]
    subgraph = _metadata_fb.SubGraphMetadataT()
    # Create a model with two inputs and one output.
    subgraph.inputTensorMetadata = [input_meta, input_meta]
    subgraph.outputTensorMetadata = [output_meta]
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "Mobilenet_quantized"
    model_meta.associatedFiles = [associated_file1]
    model_meta.subgraphMetadata = [subgraph]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    metadata_file = self.create_tempfile().full_path
    with open(metadata_file, "wb") as f:
      f.write(b.Output())
    return metadata_file

  def _create_model_buffer_with_wrong_identifier(self):
    # Serializes an empty model with an invalid file identifier.
    wrong_identifier = b"widn"
    model = _schema_fb.ModelT()
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(model.Pack(model_builder), wrong_identifier)
    return model_builder.Output()

  def _create_metadata_buffer_with_wrong_identifier(self):
    # Creates a metadata with wrong identifier
    wrong_identifier = b"widn"
    metadata = _metadata_fb.ModelMetadataT()
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
    return metadata_builder.Output()

  def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
                                         identifier):
    # For testing purposes only. MetadataPopulator cannot populate metadata with
    # wrong identifiers.
    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = metadata_buf
    model.buffers = [buffer_field]
    # Creates a new metadata field.
    metadata_field = _schema_fb.MetadataT()
    metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
    metadata_field.buffer = len(model.buffers) - 1
    model.metadata = [metadata_field]
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), identifier)
    return b.Output()

  def _create_metadata_file_with_version(self, metadata_file, min_version):
    # Creates a new metadata file with the specified min_version for testing
    # purposes.
    metadata_buf = bytearray(_read_file(metadata_file))
    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version
    b = flatbuffers.Builder(0)
    b.Finish(
        metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    metadata_file_with_version = self.create_tempfile().full_path
    with open(metadata_file_with_version, "wb") as f:
      f.write(b.Output())
    return metadata_file_with_version
class MetadataPopulatorTest(MetadataTest):
  """Tests for metadata.MetadataPopulator (loading, packing, validation)."""

  def _create_bert_tokenizer(self):
    # Returns a Bert tokenizer process unit plus the names of the files it
    # references.
    vocab_file_name = "bert_vocab"
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = vocab_file_name
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
    tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
    tokenizer.options.vocabFile = [vocab]
    return tokenizer, [vocab_file_name]

  def _create_sentence_piece_tokenizer(self):
    # Returns a SentencePiece tokenizer process unit plus the names of the
    # files it references (model + vocab).
    sp_model_name = "sp_model"
    vocab_file_name = "sp_vocab"
    sp_model = _metadata_fb.AssociatedFileT()
    sp_model.name = sp_model_name
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = vocab_file_name
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = (
        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
    tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
    tokenizer.options.sentencePieceModel = [sp_model]
    tokenizer.options.vocabFile = [vocab]
    return tokenizer, [sp_model_name, vocab_file_name]

  def _create_tokenizer(self, tokenizer_type):
    # Dispatches to the factory that matches tokenizer_type.
    if tokenizer_type is Tokenizer.BERT_TOKENIZER:
      return self._create_bert_tokenizer()
    elif tokenizer_type is Tokenizer.SENTENCE_PIECE:
      return self._create_sentence_piece_tokenizer()
    else:
      raise ValueError(
          "The tokenizer type, {0}, is unsupported.".format(tokenizer_type))

  def _create_tempfiles(self, file_names):
    # Creates one temp file per name; returns their full paths.
    tempfiles = []
    for name in file_names:
      tempfiles.append(self.create_tempfile(name).full_path)
    return tempfiles

  def _create_model_meta_with_subgraph_meta(self, subgraph_meta):
    # Wraps a single subgraph metadata into a serialized model metadata buffer.
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.subgraphMetadata = [subgraph_meta]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    return b.Output()

  def testToValidModelFile(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    self.assertIsInstance(populator, _metadata.MetadataPopulator)

  def testToInvalidModelFile(self):
    with self.assertRaises(IOError) as error:
      _metadata.MetadataPopulator.with_model_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testToValidModelBuffer(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    self.assertIsInstance(populator, _metadata.MetadataPopulator)

  def testToInvalidModelBuffer(self):
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
    self.assertEqual("model_buf cannot be empty.", str(error.exception))

  def testToModelBufferWithWrongIdentifier(self):
    model_buf = self._create_model_buffer_with_wrong_identifier()
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataPopulator.with_model_buffer(model_buf)
    self.assertEqual(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.", str(error.exception))

  def testSinglePopulateAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    populator.load_associated_files([self._file1])
    populator.populate()
    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [os.path.basename(self._file1)]
    self.assertEqual(set(packed_files), set(expected_packed_files))

  def testRepeatedPopulateAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_associated_files([self._file1, self._file2])
    # Loads file2 multiple times.
    populator.load_associated_files([self._file2])
    populator.populate()
    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertLen(packed_files, 2)
    self.assertEqual(set(packed_files), set(expected_packed_files))
    # Check if the model buffer read from file is the same as that read from
    # get_model_buffer().
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateInvalidAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(IOError) as error:
      populator.load_associated_files([self._invalid_file])
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testPopulatePackedAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    populator.load_associated_files([self._file1])
    populator.populate()
    # Packing the same file twice must be rejected.
    with self.assertRaises(ValueError) as error:
      populator.load_associated_files([self._file1])
      populator.populate()
    self.assertEqual(
        "File, '{0}', has already been packed.".format(
            os.path.basename(self._file1)), str(error.exception))

  def testLoadAssociatedFileBuffers(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    file_buffer = _read_file(self._file1)
    populator.load_associated_file_buffers({self._file1: file_buffer})
    populator.populate()
    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [os.path.basename(self._file1)]
    self.assertEqual(set(packed_files), set(expected_packed_files))

  def testRepeatedLoadAssociatedFileBuffers(self):
    file_buffer1 = _read_file(self._file1)
    file_buffer2 = _read_file(self._file2)
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_associated_file_buffers({
        self._file1: file_buffer1,
        self._file2: file_buffer2
    })
    # Loads file2 multiple times.
    populator.load_associated_file_buffers({self._file2: file_buffer2})
    populator.populate()
    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertEqual(set(packed_files), set(expected_packed_files))
    # Check if the model buffer read from file is the same as that read from
    # get_model_buffer().
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testLoadPackedAssociatedFileBuffersFails(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    file_buffer = _read_file(self._file1)
    populator.load_associated_file_buffers({self._file1: file_buffer})
    populator.populate()
    # Load file1 again should fail.
    with self.assertRaises(ValueError) as error:
      populator.load_associated_file_buffers({self._file1: file_buffer})
      populator.populate()
    self.assertEqual(
        "File, '{0}', has already been packed.".format(
            os.path.basename(self._file1)), str(error.exception))

  def testGetPackedAssociatedFileList(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    packed_files = populator.get_packed_associated_file_list()
    self.assertEqual(packed_files, [])

  def testPopulateMetadataFileToEmptyModelFile(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1, self._file2])
    populator.populate()
    model_buf_from_file = _read_file(self._model_file)
    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
    # self._model_file already has two elements in the metadata field, so the
    # populated TFLite metadata will be the third element.
    metadata_field = model.Metadata(2)
    self.assertEqual(
        six.ensure_str(metadata_field.Name()),
        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
    buffer_index = metadata_field.Buffer()
    buffer_data = model.Buffers(buffer_index)
    metadata_buf_np = buffer_data.DataAsNumpy()
    metadata_buf = metadata_buf_np.tobytes()
    # The populator stamps minParserVersion, so compare against the
    # version-stamped golden metadata.
    expected_metadata_buf = bytearray(
        _read_file(self._metadata_file_with_version))
    self.assertEqual(metadata_buf, expected_metadata_buf)
    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
    # Up to now, we've proved the correctness of the model buffer that read from
    # file. Then we'll test if get_model_buffer() gives the same model buffer.
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateMetadataFileWithoutAssociatedFiles(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1])
    # Supposed to populate self._file2, because it is recorded in the metadata.
    with self.assertRaises(ValueError) as error:
      populator.populate()
    self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
                      "not been loaded into the populator.").format(
                          os.path.basename(self._file2)), str(error.exception))

  def testPopulateMetadataBufferWithWrongIdentifier(self):
    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(metadata_buf)
    self.assertEqual(
        "The metadata buffer does not have the expected identifier, and may not"
        " be a valid TFLite Metadata.", str(error.exception))

  def _assert_golden_metadata(self, model_file):
    # Asserts that model_file carries exactly the version-stamped golden
    # metadata as its third metadata entry.
    model_buf_from_file = _read_file(model_file)
    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
    # There are two elements in model.Metadata array before the population.
    # Metadata should be packed to the third element in the array.
    metadata_field = model.Metadata(2)
    self.assertEqual(
        six.ensure_str(metadata_field.Name()),
        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
    buffer_index = metadata_field.Buffer()
    buffer_data = model.Buffers(buffer_index)
    metadata_buf_np = buffer_data.DataAsNumpy()
    metadata_buf = metadata_buf_np.tobytes()
    expected_metadata_buf = bytearray(
        _read_file(self._metadata_file_with_version))
    self.assertEqual(metadata_buf, expected_metadata_buf)

  def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
    # First, creates a dummy metadata different from self._metadata_file. It
    # needs to have the same input/output tensor numbers as self._model_file.
    # Populates it and the associated files into the model.
    input_meta = _metadata_fb.TensorMetadataT()
    output_meta = _metadata_fb.TensorMetadataT()
    subgraph = _metadata_fb.SubGraphMetadataT()
    # Create a model with two inputs and one output.
    subgraph.inputTensorMetadata = [input_meta, input_meta]
    subgraph.outputTensorMetadata = [output_meta]
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.subgraphMetadata = [subgraph]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    metadata_buf = b.Output()
    # Populate the metadata.
    populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator1.load_metadata_buffer(metadata_buf)
    populator1.load_associated_files([self._file1, self._file2])
    populator1.populate()
    # Then, populate the metadata again.
    populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator2.load_metadata_file(self._metadata_file)
    populator2.populate()
    # Test if the metadata is populated correctly.
    self._assert_golden_metadata(self._model_file)

  def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1, self._file2])
    populator.populate()
    # Tests if the metadata is populated correctly.
    self._assert_golden_metadata(self._model_file)
    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
    # Up to now, we've proved the correctness of the model buffer that read from
    # file. Then we'll test if get_model_buffer() gives the same model buffer.
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateInvalidMetadataFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(IOError) as error:
      populator.load_metadata_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testPopulateInvalidMetadataBuffer(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer([])
    self.assertEqual("The metadata to be populated is empty.",
                     str(error.exception))

  def testGetModelBufferBeforePopulatingData(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    model_buf = populator.get_model_buffer()
    expected_model_buf = self._model_buf
    self.assertEqual(model_buf, expected_model_buf)

  def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self):
    # Create a dummy metadata without Subgraph.
    model_meta = _metadata_fb.ModelMetadataT()
    builder = flatbuffers.Builder(0)
    builder.Finish(
        model_meta.Pack(builder),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    meta_buf = builder.Output()
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(meta_buf)
    self.assertEqual(
        "The number of SubgraphMetadata should be exactly one, but got 0.",
        str(error.exception))

  def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self):
    # Create a dummy metadata with no input tensor metadata, while the expected
    # number is 2.
    output_meta = _metadata_fb.TensorMetadataT()
    subgprah_meta = _metadata_fb.SubGraphMetadataT()
    subgprah_meta.outputTensorMetadata = [output_meta]
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.subgraphMetadata = [subgprah_meta]
    builder = flatbuffers.Builder(0)
    builder.Finish(
        model_meta.Pack(builder),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    meta_buf = builder.Output()
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(meta_buf)
    self.assertEqual(
        ("The number of input tensors (2) should match the number of "
         "input tensor metadata (0)"), str(error.exception))

  def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self):
    # Create a dummy metadata with no output tensor metadata, while the expected
    # number is 1.
    input_meta = _metadata_fb.TensorMetadataT()
    subgprah_meta = _metadata_fb.SubGraphMetadataT()
    subgprah_meta.inputTensorMetadata = [input_meta, input_meta]
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.subgraphMetadata = [subgprah_meta]
    builder = flatbuffers.Builder(0)
    builder.Finish(
        model_meta.Pack(builder),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    meta_buf = builder.Output()
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(meta_buf)
    self.assertEqual(
        ("The number of output tensors (1) should match the number of "
         "output tensor metadata (0)"), str(error.exception))

  def testLoadMetadataAndAssociatedFilesShouldSucceeds(self):
    # Create a src model with metadata and two associated files.
    src_model_buf = self._create_model_buf()
    populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
    populator_src.load_metadata_file(self._metadata_file)
    populator_src.load_associated_files([self._file1, self._file2])
    populator_src.populate()
    # Create a model to be populated with the metadata and files from
    # src_model_buf.
    dst_model_buf = self._create_model_buf()
    populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
    populator_dst.load_metadata_and_associated_files(
        populator_src.get_model_buffer())
    populator_dst.populate()
    # Tests if the metadata and associated files are populated correctly.
    dst_model_file = self.create_tempfile().full_path
    with open(dst_model_file, "wb") as f:
      f.write(populator_dst.get_model_buffer())
    self._assert_golden_metadata(dst_model_file)
    recorded_files = populator_dst.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))

  @parameterized.named_parameters(
      {
          "testcase_name": "InputTensorWithBert",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "OutputTensorWithBert",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "InputTensorWithSentencePiece",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      }, {
          "testcase_name": "OutputTensorWithSentencePiece",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      })
  def testGetRecordedAssociatedFileListWithSubgraphTensor(
      self, tensor_type, tokenizer_type):
    # Creates a metadata with the tokenizer in the tensor process units.
    tokenizer, expected_files = self._create_tokenizer(tokenizer_type)
    # Create the tensor with process units.
    tensor = _metadata_fb.TensorMetadataT()
    tensor.processUnits = [tokenizer]
    # Create the subgraph with the tensor.
    subgraph = _metadata_fb.SubGraphMetadataT()
    dummy_tensor_meta = _metadata_fb.TensorMetadataT()
    subgraph.outputTensorMetadata = [dummy_tensor_meta]
    if tensor_type is TensorType.INPUT:
      subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta]
      subgraph.outputTensorMetadata = [dummy_tensor_meta]
    elif tensor_type is TensorType.OUTPUT:
      subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
      subgraph.outputTensorMetadata = [tensor]
    else:
      raise ValueError(
          "The tensor type, {0}, is unsupported.".format(tensor_type))
    # Create a model metadata with the subgraph metadata
    meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
    # Creates the tempfiles.
    tempfiles = self._create_tempfiles(expected_files)
    # Creates the MetadataPopulator object.
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_buffer(meta_buffer)
    populator.load_associated_files(tempfiles)
    populator.populate()
    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(expected_files))

  @parameterized.named_parameters(
      {
          "testcase_name": "InputTensorWithBert",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "OutputTensorWithBert",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "InputTensorWithSentencePiece",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      }, {
          "testcase_name": "OutputTensorWithSentencePiece",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      })
  def testGetRecordedAssociatedFileListWithSubgraphProcessUnits(
      self, tensor_type, tokenizer_type):
    # Creates a metadata with the tokenizer in the subgraph process units.
    tokenizer, expected_files = self._create_tokenizer(tokenizer_type)
    # Create the subgraph with process units.
    subgraph = _metadata_fb.SubGraphMetadataT()
    if tensor_type is TensorType.INPUT:
      subgraph.inputProcessUnits = [tokenizer]
    elif tensor_type is TensorType.OUTPUT:
      subgraph.outputProcessUnits = [tokenizer]
    else:
      raise ValueError(
          "The tensor type, {0}, is unsupported.".format(tensor_type))
    # Creates the input and output tensor meta to match self._model_file.
    dummy_tensor_meta = _metadata_fb.TensorMetadataT()
    subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
    subgraph.outputTensorMetadata = [dummy_tensor_meta]
    # Create a model metadata with the subgraph metadata
    meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
    # Creates the tempfiles.
    tempfiles = self._create_tempfiles(expected_files)
    # Creates the MetadataPopulator object.
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_buffer(meta_buffer)
    populator.load_associated_files(tempfiles)
    populator.populate()
    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(expected_files))

  def testPopulatedFullPathAssociatedFileShouldSucceed(self):
    # Create AssociatedFileT using the full path file name.
    associated_file = _metadata_fb.AssociatedFileT()
    associated_file.name = self._file1
    # Create model metadata with the associated file.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.associatedFiles = [associated_file]
    # Creates the input and output tensor metadata to match self._model_file.
    dummy_tensor = _metadata_fb.TensorMetadataT()
    subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor]
    subgraph.outputTensorMetadata = [dummy_tensor]
    md_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
    # Populate the metadata to a model.
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_buffer(md_buffer)
    populator.load_associated_files([self._file1])
    populator.populate()
    # The recorded file name in metadata should only contain file basename; file
    # directory should not be included.
    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)]))
class MetadataDisplayerTest(MetadataTest):
  """Tests for MetadataDisplayer, which reads metadata back out of models."""

  def setUp(self):
    # Zero-argument super() is the Python 3 idiom and matches the style of
    # the other Tasks Python tests.
    super().setUp()
    self._model_with_meta_file = (
        self._create_model_with_metadata_and_associated_files())

  def _create_model_with_metadata_and_associated_files(self):
    """Builds a model file populated with metadata plus file1 and file2.

    Returns:
      Path to the populated temporary model file.
    """
    model_buf = self._create_model_buf()
    model_file = self.create_tempfile().full_path
    with open(model_file, "wb") as f:
      f.write(model_buf)
    populator = _metadata.MetadataPopulator.with_model_file(model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1, self._file2])
    populator.populate()
    return model_file

  def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self):
    """A metadata buffer with a bad identifier is rejected with ValueError."""
    model_buf = self._create_model_buffer_with_wrong_identifier()
    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
    model_buf = self._populate_metadata_with_identifier(
        model_buf, metadata_buf,
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(model_buf)
    self.assertEqual(
        "The metadata buffer does not have the expected identifier, and may not"
        " be a valid TFLite Metadata.", str(error.exception))

  def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self):
    """A model buffer with a bad file identifier is rejected with ValueError."""
    model_buf = self._create_model_buffer_with_wrong_identifier()
    metadata_file = self._create_metadata_file()
    wrong_identifier = b"widn"
    metadata_buf = bytearray(_read_file(metadata_file))
    model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
                                                        wrong_identifier)
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(model_buf)
    self.assertEqual(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.", str(error.exception))

  def testLoadModelFileInvalidModelFileThrowsException(self):
    """A nonexistent model path raises IOError with the path in the message."""
    with self.assertRaises(IOError) as error:
      _metadata.MetadataDisplayer.with_model_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testLoadModelFileModelWithoutMetadataThrowsException(self):
    """Opening a metadata-free model file raises ValueError."""
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_file(self._model_file)
    self.assertEqual("The model does not have metadata.", str(error.exception))

  def testLoadModelFileModelWithMetadata(self):
    """A model file that carries metadata yields a MetadataDisplayer."""
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    self.assertIsInstance(displayer, _metadata.MetadataDisplayer)

  def testLoadModelBufferInvalidModelBufferThrowsException(self):
    """A buffer that is not a flatbuffer model raises ValueError."""
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1))
    self.assertEqual("model_buffer cannot be empty.", str(error.exception))

  def testLoadModelBufferModelWithOutMetadataThrowsException(self):
    """A metadata-free model buffer raises ValueError."""
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf())
    self.assertEqual("The model does not have metadata.", str(error.exception))

  def testLoadModelBufferModelWithMetadata(self):
    """A model buffer that carries metadata yields a MetadataDisplayer."""
    displayer = _metadata.MetadataDisplayer.with_model_buffer(
        _read_file(self._model_with_meta_file))
    self.assertIsInstance(displayer, _metadata.MetadataDisplayer)

  def testGetAssociatedFileBufferShouldSucceed(self):
    """get_associated_file_buffer returns the packed file's content."""
    # _model_with_meta_file contains file1 and file2.
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    actual_content = displayer.get_associated_file_buffer("file2")
    self.assertEqual(actual_content, self._file2_content)

  def testGetAssociatedFileBufferFailsWithNonExistentFile(self):
    """Requesting a file not packed in the model raises ValueError."""
    # _model_with_meta_file contains file1 and file2.
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    non_existent_file = "non_existent_file"
    with self.assertRaises(ValueError) as error:
      displayer.get_associated_file_buffer(non_existent_file)
    self.assertEqual(
        "The file, {}, does not exist in the model.".format(non_existent_file),
        str(error.exception))

  def testGetMetadataBufferShouldSucceed(self):
    """The raw metadata buffer converts to the golden JSON."""
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    actual_buffer = displayer.get_metadata_buffer()
    actual_json = _metadata.convert_to_json(actual_buffer)

    # Verifies the generated json file.
    golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
    with open(golden_json_file_path, "r") as f:
      expected = f.read()
    self.assertEqual(actual_json, expected)

  def testGetMetadataJsonModelWithMetadata(self):
    """get_metadata_json matches the golden JSON file."""
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    actual = displayer.get_metadata_json()

    # Verifies the generated json file.
    golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
    expected = _read_file(golden_json_file_path, "r")
    self.assertEqual(actual, expected)

  def testGetPackedAssociatedFileListModelWithMetadata(self):
    """The packed-file list contains exactly the basenames of file1 and file2."""
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    packed_files = displayer.get_packed_associated_file_list()

    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertLen(
        packed_files, 2,
        "The following two associated files packed to the model: {0}; {1}"
        .format(expected_packed_files[0], expected_packed_files[1]))
    self.assertEqual(set(packed_files), set(expected_packed_files))
class MetadataUtilTest(MetadataTest):
  """Tests for the module-level metadata utility functions."""

  def test_convert_to_json_should_succeed(self):
    """convert_to_json renders a metadata buffer as the golden JSON text."""
    raw_metadata = _read_file(self._metadata_file_with_version)
    rendered_json = _metadata.convert_to_json(raw_metadata)

    # Compare against the checked-in golden JSON file.
    golden_path = test_utils.get_test_data_path("golden_json.json")
    golden_json = _read_file(golden_path, "r")
    self.assertEqual(rendered_json, golden_json)
# Runs every test case defined above under absltest's runner when this file
# is executed directly.
if __name__ == "__main__":
  absltest.main()

View File

@ -0,0 +1,45 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test util for MediaPipe Tasks."""
import os
from absl import flags
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
FLAGS = flags.FLAGS
_Image = image_module.Image
_ImageFormat = image_frame_module.ImageFormat
_RGB_CHANNELS = 3
def test_srcdir():
  """Returns the path where to look for test data files."""
  # Prefer the absl flag when the test runner defined it; otherwise fall
  # back to the TEST_SRCDIR environment variable set by Bazel.
  if "test_srcdir" in flags.FLAGS:
    return flags.FLAGS["test_srcdir"].value
  if "TEST_SRCDIR" in os.environ:
    return os.environ["TEST_SRCDIR"]
  raise RuntimeError("Missing TEST_SRCDIR environment.")
def get_test_data_path(file_or_dirname: str) -> str:
  """Returns full test data path."""
  # Walk the runfiles tree and return the first entry — file or directory —
  # whose name ends with the requested suffix.
  for root, dirs, names in os.walk(test_srcdir()):
    for entry in dirs + names:
      if entry.endswith(file_or_dirname):
        return os.path.join(root, entry)
  raise ValueError("No %s in test directory" % file_or_dirname)

View File

@ -31,7 +31,7 @@ py_test(
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:detections", "//mediapipe/tasks/python/components/containers:detections",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_util", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:object_detector", "//mediapipe/tasks/python/vision:object_detector",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],

View File

@ -25,7 +25,7 @@ from mediapipe.tasks.python.components.containers import bounding_box as boundin
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import detections as detections_module from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_util from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import object_detector from mediapipe.tasks.python.vision import object_detector
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
@ -99,8 +99,8 @@ class ObjectDetectorTest(parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_util.get_test_data_path(_IMAGE_FILE)) test_utils.get_test_data_path(_IMAGE_FILE))
self.model_path = test_util.get_test_data_path(_MODEL_FILE) self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.

View File

@ -28,9 +28,13 @@ mediapipe_files(srcs = [
"mobile_ica_8bit-without-model-metadata.tflite", "mobile_ica_8bit-without-model-metadata.tflite",
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v2_1.0_224_quant.tflite",
]) ])
exports_files(["external_file"]) exports_files([
"external_file",
"golden_json.json",
])
filegroup( filegroup(
name = "model_files", name = "model_files",
@ -40,10 +44,14 @@ filegroup(
"mobile_ica_8bit-without-model-metadata.tflite", "mobile_ica_8bit-without-model-metadata.tflite",
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v2_1.0_224_quant.tflite",
], ],
) )
filegroup( filegroup(
name = "data_files", name = "data_files",
srcs = ["external_file"], srcs = [
"external_file",
"golden_json.json",
],
) )

View File

@ -0,0 +1,28 @@
{
"name": "Mobilenet_quantized",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
},
{
}
],
"output_tensor_metadata": [
{
"associated_files": [
{
"name": "file2"
}
]
}
]
}
],
"associated_files": [
{
"name": "file1"
}
],
"min_parser_version": "1.0.0"
}

View File

@ -166,6 +166,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"], urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"],
) )
http_file(
name = "com_google_mediapipe_golden_json_json",
sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c",
urls = ["https://storage.googleapis.com/mediapipe-assets/golden_json.json?generation=1664340169675228"],
)
http_file( http_file(
name = "com_google_mediapipe_hair_segmentation_tflite", name = "com_google_mediapipe_hair_segmentation_tflite",
sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633", sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633",
@ -316,6 +322,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"], urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
) )
http_file(
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite",
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"],
)
http_file( http_file(
name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite", name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339", sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",