Migrate base metadata functionality, such as the MetadataPopulator and MetadataDisplayer classes, into MediaPipe.

PiperOrigin-RevId: 478279747
Parent: 9568de0570
Commit: 13f6e0c797
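After this change, the metadata tooling is importable from the MediaPipe package path exercised by the new CLI and tests below. A minimal sketch (the model path is a placeholder):

```python
# Sketch only: mirrors the imports and calls used by metadata_displayer_cli.py
# in this commit; "model.tflite" is a placeholder path.
from mediapipe.tasks.python.metadata import metadata

displayer = metadata.MetadataDisplayer.with_model_file("model.tflite")
print(displayer.get_metadata_json())
```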
mediapipe/tasks/metadata/BUILD (modified)
@@ -1,4 +1,4 @@
-load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
 
 package(
     default_visibility = [
@@ -14,3 +14,13 @@ flatbuffer_cc_library(
     name = "metadata_schema_cc",
     srcs = ["metadata_schema.fbs"],
 )
+
+flatbuffer_py_library(
+    name = "schema_py",
+    srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
+)
+
+flatbuffer_py_library(
+    name = "metadata_schema_py",
+    srcs = ["metadata_schema.fbs"],
+)
mediapipe/tasks/python/metadata/BUILD (new file)
@@ -0,0 +1,40 @@
load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version")

package(
    licenses = ["notice"],  # Apache 2.0
)

stamp_metadata_parser_version(
    name = "metadata_parser_py",
    srcs = ["metadata_parser.py.template"],
    outs = ["metadata_parser.py"],
)

py_library(
    name = "metadata",
    srcs = [
        "metadata.py",
        ":metadata_parser_py",
    ],
    data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version",
        "//mediapipe/tasks/metadata:metadata_schema_py",
        "//mediapipe/tasks/metadata:schema_py",
        "//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers",
        "@flatbuffers//:runtime_py",
    ],
)

py_binary(
    name = "metadata_displayer_cli",
    srcs = ["metadata_displayer_cli.py"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":metadata",
    ],
)
mediapipe/tasks/python/metadata/__init__.py (new file)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD (new file)
@@ -0,0 +1,20 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")

package(
    default_visibility = ["//mediapipe/tasks:internal"],
    licenses = ["notice"],  # Apache 2.0
)

pybind_extension(
    name = "_pywrap_flatbuffers",
    srcs = [
        "flatbuffers_lib.cc",
    ],
    features = ["-use_header_modules"],
    module_name = "_pywrap_flatbuffers",
    deps = [
        "@flatbuffers",
        "@local_config_python//:python_headers",
        "@pybind11",
    ],
)
mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc (new file)
@@ -0,0 +1,59 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/idl.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"

namespace tflite {
namespace support {

PYBIND11_MODULE(_pywrap_flatbuffers, m) {
  pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
      .def(pybind11::init<>())
      .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
  pybind11::class_<flatbuffers::Parser>(m, "Parser")
      .def(pybind11::init<const flatbuffers::IDLOptions&>())
      .def("parse",
           [](flatbuffers::Parser* self, const std::string& source) {
             return self->Parse(source.c_str());
           })
      .def_readonly("builder", &flatbuffers::Parser::builder_)
      .def_readonly("error", &flatbuffers::Parser::error_);
  pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
      .def("clear", &flatbuffers::FlatBufferBuilder::Clear)
      .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
                                  const std::string& contents) {
        self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
                             contents.length());
      });
  m.def("generate_text_file", &flatbuffers::GenerateTextFile);
  m.def(
      "generate_text",
      [](const flatbuffers::Parser& parser,
         const std::string& buffer) -> std::string {
        std::string text;
        if (!flatbuffers::GenerateText(
                parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
          return "";
        }
        return text;
      });
}

}  // namespace support
}  // namespace tflite
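The metadata.py module introduced below drives these bindings to render a metadata FlatBuffer as JSON. A minimal sketch of that flow, assuming the caller supplies a metadata buffer and a path to metadata_schema.fbs (both placeholders here):

```python
# Sketch only: mirrors the convert_to_json() helper defined later in this commit.
from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers


def metadata_to_json(schema_path, metadata_buf):
  """Renders a metadata FlatBuffer as JSON text via the pybind bindings."""
  opt = _pywrap_flatbuffers.IDLOptions()
  opt.strict_json = True
  parser = _pywrap_flatbuffers.Parser(opt)
  with open(schema_path) as f:
    if not parser.parse(f.read()):
      raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
  return _pywrap_flatbuffers.generate_text(parser, metadata_buf)
```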
mediapipe/tasks/python/metadata/metadata.py (new file)
@@ -0,0 +1,865 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""

import copy
import inspect
import io
import os
import shutil
import sys
import tempfile
import warnings
import zipfile

import flatbuffers
from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers

try:
  # If it exists, optionally use TensorFlow to open and check files. Used to
  # support more than local file systems.
  # In pip requirements, we don't necessarily need TensorFlow as a dep.
  import tensorflow as tf
  _open_file = tf.io.gfile.GFile
  _exists_file = tf.io.gfile.exists
except ImportError as e:
  # If the TensorFlow package doesn't exist, fall back to original open and exists.
  _open_file = open
  _exists_file = os.path.exists


def _maybe_open_as_binary(filename, mode):
  """Maybe opens the file in binary mode, and returns a file-like object."""
  if hasattr(filename, "read"):  # A file-like has read().
    return filename
  openmode = mode if "b" in mode else mode + "b"  # Add binary explicitly.
  return _open_file(filename, openmode)


def _open_as_zipfile(filename, mode="r"):
  """Opens the file as a zipfile.

  Args:
    filename: str or file-like or path-like, to the zipfile.
    mode: str, common file mode for zip.
      (See: https://docs.python.org/3/library/zipfile.html)

  Returns:
    A ZipFile object.
  """
  file_like = _maybe_open_as_binary(filename, mode)
  return zipfile.ZipFile(file_like, mode)


def _is_zipfile(filename):
  """Checks whether it is a zipfile."""
  with _maybe_open_as_binary(filename, "r") as f:
    return zipfile.is_zipfile(f)


def get_path_to_datafile(path):
  """Gets the path to the specified file in the data dependencies.

  The path is relative to the file calling the function.

  It's a simple replacement of
  "tensorflow.python.platform.resource_loader.get_path_to_datafile".

  Args:
    path: a string resource path relative to the calling file.

  Returns:
    The path to the specified file present in the data attribute of py_test
    or py_binary.
  """
  data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1)))  # pylint: disable=protected-access
  return os.path.join(data_files_path, path)


_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile(
    "../../metadata/metadata_schema.fbs")


# TODO: add delete method for associated files.
class MetadataPopulator(object):
  """Packs metadata and associated files into a TensorFlow Lite model file.

  MetadataPopulator can be used to populate metadata and model associated files
  into a model file or a model buffer (in bytearray). It can also help to
  inspect the list of files that have been packed into the model or are
  supposed to be packed into the model.

  The metadata file (or buffer) should be generated based on the metadata
  schema:
  third_party/tensorflow/lite/schema/metadata_schema.fbs

  Example usage:
  Populate metadata and a label file into an image classifier model.

  First, based on metadata_schema.fbs, generate the metadata for this image
  classifier model using the Flatbuffers API. Attach the label file onto the
  output tensor (the tensor of probabilities) in the metadata.

  Then, pack the metadata and label file into the model as follows.

    ```python
    # Populating a metadata file (or a metadata buffer) and associated files to
    # a model file:
    populator = MetadataPopulator.with_model_file(model_file)
    # For metadata buffer (bytearray read from the metadata file), use:
    # populator.load_metadata_buffer(metadata_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files(["label.txt"])
    # For associated file buffer (bytearray read from the file), use:
    # populator.load_associated_file_buffers({"label.txt": b"file content"})
    populator.populate()

    # Populating a metadata file (or a metadata buffer) and associated files to
    # a model buffer:
    populator = MetadataPopulator.with_model_buffer(model_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files(["label.txt"])
    populator.populate()
    # Writing the updated model buffer into a file.
    updated_model_buf = populator.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)

    # Transferring metadata and associated files from another TFLite model:
    populator_dst = MetadataPopulator.with_model_buffer(model_buf)
    populator_dst.load_metadata_and_associated_files(src_model_buf)
    populator_dst.populate()
    updated_model_buf = populator_dst.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)
    ```

  Note that an existing metadata buffer (if applied) will be overridden by the
  new metadata buffer.
  """
  # As the Zip API is used to concatenate associated files after the tflite
  # model file, the populating operation is developed based on a model file.
  # For an in-memory model buffer, we create a tempfile to serve the populating
  # operation. Creating and deleting such a tempfile is handled by the class,
  # _MetadataPopulatorWithBuffer.

  METADATA_FIELD_NAME = "TFLITE_METADATA"
  TFLITE_FILE_IDENTIFIER = b"TFL3"
  METADATA_FILE_IDENTIFIER = b"M001"

  def __init__(self, model_file):
    """Constructor for MetadataPopulator.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    _assert_model_file_identifier(model_file)
    self._model_file = model_file
    self._metadata_buf = None
    # _associated_files is a dict of file name and file buffer.
    self._associated_files = {}

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataPopulator object that populates data to a model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataPopulator object.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return cls(model_file)

  # TODO: investigate if type check can be applied to model_buf for
  # FB.
  @classmethod
  def with_model_buffer(cls, model_buf):
    """Creates a MetadataPopulator object that populates data to a model buffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Returns:
      A MetadataPopulator(_MetadataPopulatorWithBuffer) object.

    Raises:
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return _MetadataPopulatorWithBuffer(model_buf)

  def get_model_buffer(self):
    """Gets the buffer of the model with packed metadata and associated files.

    Returns:
      Model buffer (in bytearray).
    """
    with _open_file(self._model_file, "rb") as f:
      return f.read()

  def get_packed_associated_file_list(self):
    """Gets a list of associated files packed to the model file.

    Returns:
      List of packed associated files.
    """
    if not _is_zipfile(self._model_file):
      return []

    with _open_as_zipfile(self._model_file, "r") as zf:
      return zf.namelist()

  def get_recorded_associated_file_list(self):
    """Gets a list of associated files recorded in metadata of the model file.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Returns:
      List of recorded associated files.
    """
    if not self._metadata_buf:
      return []

    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
            self._metadata_buf, 0))

    return [
        file.name.decode("utf-8")
        for file in self._get_recorded_associated_file_object_list(metadata)
    ]

  def load_associated_file_buffers(self, associated_files):
    """Loads the associated file buffers (in bytearray) to be populated.

    Args:
      associated_files: a dictionary of associated file names and corresponding
        file buffers, such as {"file.txt": b"file content"}. If file paths are
        passed in for the file names, only the basename will be populated.
    """

    self._associated_files.update({
        os.path.basename(name): buffers
        for name, buffers in associated_files.items()
    })

  def load_associated_files(self, associated_files):
    """Loads associated files to be concatenated after the model file.

    Args:
      associated_files: list of file paths.

    Raises:
      IOError:
        File not found.
    """
    for af_name in associated_files:
      _assert_file_exist(af_name)
      with _open_file(af_name, "rb") as af:
        self.load_associated_file_buffers({af_name: af.read()})

  def load_metadata_buffer(self, metadata_buf):
    """Loads the metadata buffer (in bytearray) to be populated.

    Args:
      metadata_buf: metadata buffer (in bytearray) to be populated.

    Raises:
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    if not metadata_buf:
      raise ValueError("The metadata to be populated is empty.")

    self._validate_metadata(metadata_buf)

    # Gets the minimum metadata parser version of the metadata_buf.
    min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
        bytes(metadata_buf))

    # Inserts the minimum metadata parser version into the metadata_buf.
    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version

    # Remove local file directory in the `name` field of `AssociatedFileT`, and
    # make it consistent with the name of the actual file packed in the model.
    self._use_basename_for_associated_files_in_metadata(metadata)

    b = flatbuffers.Builder(0)
    b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
    metadata_buf_with_version = b.Output()

    self._metadata_buf = metadata_buf_with_version

  def load_metadata_file(self, metadata_file):
    """Loads the metadata file to be populated.

    Args:
      metadata_file: path to the metadata file to be populated.

    Raises:
      IOError: File not found.
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    _assert_file_exist(metadata_file)
    with _open_file(metadata_file, "rb") as f:
      metadata_buf = f.read()
    self.load_metadata_buffer(bytearray(metadata_buf))

  def load_metadata_and_associated_files(self, src_model_buf):
    """Loads the metadata and associated files from another model buffer.

    Args:
      src_model_buf: source model buffer (in bytearray) with metadata and
        associated files.
    """
    # Load the model metadata from src_model_buf if it exists.
    metadata_buffer = get_metadata_buffer(src_model_buf)
    if metadata_buffer:
      self.load_metadata_buffer(metadata_buffer)

    # Load the associated files from src_model_buf if they exist.
    if _is_zipfile(io.BytesIO(src_model_buf)):
      with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
        self.load_associated_file_buffers(
            {f: zf.read(f) for f in zf.namelist()})

  def populate(self):
    """Populates loaded metadata and associated files into the model file."""
    self._assert_validate()
    self._populate_metadata_buffer()
    self._populate_associated_files()

  def _assert_validate(self):
    """Validates the metadata and associated files to be populated.

    Raises:
      ValueError:
        File is recorded in the metadata, but is not going to be populated.
        File has already been packed.
    """
    # Gets files that are recorded in metadata.
    recorded_files = self.get_recorded_associated_file_list()

    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()

    # Gets the file name of those associated files to be populated.
    to_be_populated_files = self._associated_files.keys()

    # Checks all files recorded in the metadata will be populated.
    for rf in recorded_files:
      if rf not in to_be_populated_files and rf not in packed_files:
        raise ValueError("File, '{0}', is recorded in the metadata, but has "
                         "not been loaded into the populator.".format(rf))

    for f in to_be_populated_files:
      if f in packed_files:
        raise ValueError("File, '{0}', has already been packed.".format(f))

      if f not in recorded_files:
        warnings.warn(
            "File, '{0}', does not exist in the metadata. But packing it to "
            "tflite model is still allowed.".format(f))

  def _copy_archived_files(self, src_zip, file_list, dst_zip):
    """Copies archived files in file_list from src_zip to dst_zip."""

    if not _is_zipfile(src_zip):
      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))

    with _open_as_zipfile(src_zip, "r") as src_zf, \
         _open_as_zipfile(dst_zip, "a") as dst_zf:
      src_list = src_zf.namelist()
      for f in file_list:
        if f not in src_list:
          raise ValueError(
              "File, '{0}', does not exist in the zipfile, {1}.".format(
                  f, src_zip))
        file_buffer = src_zf.read(f)
        dst_zf.writestr(f, file_buffer)

  def _get_associated_files_from_process_units(self, table, field_name):
    """Gets the files that are attached to the process units field of a table.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        ProcessUnit, such as TensorMetadata and SubGraphMetadata.
      field_name: the name of the field in the table that represents an array
        of ProcessUnit. If the table is TensorMetadata, field_name can be
        "ProcessUnits". If the table is SubGraphMetadata, field_name can be
        either "InputProcessUnits" or "OutputProcessUnits".

    Returns:
      A list of AssociatedFileT objects.
    """

    if table is None:
      return []

    file_list = []
    process_units = getattr(table, field_name)
    # If the process_units field is not populated, it will be None. Use an
    # empty list to skip the check.
    for process_unit in process_units or []:
      options = process_unit.options
      if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
                              _metadata_fb.RegexTokenizerOptionsT)):
        file_list += self._get_associated_files_from_table(options, "vocabFile")
      elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
        file_list += self._get_associated_files_from_table(
            options, "sentencePieceModel")
        file_list += self._get_associated_files_from_table(options, "vocabFile")
    return file_list

  def _get_associated_files_from_table(self, table, field_name):
    """Gets the associated files that are attached to a table directly.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
      field_name: the name of the field in the table that represents an array
        of AssociatedFile. If the table is TensorMetadata, field_name can be
        "AssociatedFiles". If the table is BertTokenizerOptions, field_name can
        be "VocabFile".

    Returns:
      A list of AssociatedFileT objects.
    """

    if table is None:
      return []

    # If the associated file field is not populated,
    # `getattr(table, field_name)` will be None. Return an empty list.
    return getattr(table, field_name) or []

  def _get_recorded_associated_file_object_list(self, metadata):
    """Gets a list of AssociatedFileT objects recorded in the metadata.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Args:
      metadata: the ModelMetadataT object.

    Returns:
      List of recorded AssociatedFileT objects.
    """
    recorded_files = []

    # Add associated files attached to ModelMetadata.
    recorded_files += self._get_associated_files_from_table(
        metadata, "associatedFiles")

    # Add associated files attached to each SubgraphMetadata.
    for subgraph in metadata.subgraphMetadata or []:
      recorded_files += self._get_associated_files_from_table(
          subgraph, "associatedFiles")

      # Add associated files attached to each input tensor.
      for tensor_metadata in subgraph.inputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")

      # Add associated files attached to each output tensor.
      for tensor_metadata in subgraph.outputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")

      # Add associated files attached to the input_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "inputProcessUnits")

      # Add associated files attached to the output_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "outputProcessUnits")

    return recorded_files

  def _populate_associated_files(self):
    """Concatenates associated files after the TensorFlow Lite model file.

    If the MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.
    """
    # Opens up the model file in "appending" mode.
    # If self._model_file already has packed files, zipfile will concatenate
    # additional files after self._model_file. For example, suppose we have
    # self._model_file = old_tflite_file | label1.txt | label2.txt
    # Then after triggering populate() to add label3.txt, self._model_file becomes
    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
    with tempfile.SpooledTemporaryFile() as temp:
      # (1) Copy content from the model file to the temp file.
      with _open_file(self._model_file, "rb") as f:
        shutil.copyfileobj(f, temp)

      # (2) Append the associated files to the temp file as a zip.
      with _open_as_zipfile(temp, "a") as zf:
        for file_name, file_buffer in self._associated_files.items():
          zf.writestr(file_name, file_buffer)

      # (3) Copy the temp file back to the model file.
      temp.seek(0)
      with _open_file(self._model_file, "wb") as f:
        shutil.copyfileobj(temp, f)

  def _populate_metadata_buffer(self):
    """Populates the metadata buffer (in bytearray) into the model file.

    Inserts metadata_buf into the metadata field of schema.Model. If the
    MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.

    An existing metadata buffer (if applied) will be overridden by the new
    metadata buffer.
    """

    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()

    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = self._metadata_buf

    is_populated = False
    if not model.metadata:
      model.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model.metadata:
        if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
          is_populated = True
          model.buffers[meta.buffer] = buffer_field

    if not is_populated:
      if not model.buffers:
        model.buffers = []
      model.buffers.append(buffer_field)
      # Creates a new metadata field.
      metadata_field = _schema_fb.MetadataT()
      metadata_field.name = self.METADATA_FIELD_NAME
      metadata_field.buffer = len(model.buffers) - 1
      model.metadata.append(metadata_field)

    # Packs the model back to a flatbuffer binary file.
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
    model_buf = b.Output()

    # Saves the updated model buffer to the model file.
    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()
    if packed_files:
      # Writes the updated model buffer and associated files into a new model
      # file (in memory). Then overwrites the original model file.
      with tempfile.SpooledTemporaryFile() as temp:
        temp.write(model_buf)
        self._copy_archived_files(self._model_file, packed_files, temp)
        temp.seek(0)
        with _open_file(self._model_file, "wb") as f:
          shutil.copyfileobj(temp, f)
    else:
      with _open_file(self._model_file, "wb") as f:
        f.write(model_buf)

  def _use_basename_for_associated_files_in_metadata(self, metadata):
    """Removes any associated file local directory (if it exists)."""
    for file in self._get_recorded_associated_file_object_list(metadata):
      file.name = os.path.basename(file.name)

  def _validate_metadata(self, metadata_buf):
    """Validates the metadata to be populated."""
    _assert_metadata_buffer_identifier(metadata_buf)

    # Verify the number of SubgraphMetadata is exactly one.
    # TFLite currently only supports one subgraph.
    model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
        metadata_buf, 0)
    if model_meta.SubgraphMetadataLength() != 1:
      raise ValueError("The number of SubgraphMetadata should be exactly one, "
                       "but got {0}.".format(
                           model_meta.SubgraphMetadataLength()))

    # Verify if the number of tensor metadata matches the number of tensors.
    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()
    model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

    num_input_tensors = model.Subgraphs(0).InputsLength()
    num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
    if num_input_tensors != num_input_meta:
      raise ValueError(
          "The number of input tensors ({0}) should match the number of "
          "input tensor metadata ({1})".format(num_input_tensors,
                                               num_input_meta))
    num_output_tensors = model.Subgraphs(0).OutputsLength()
    num_output_meta = model_meta.SubgraphMetadata(
        0).OutputTensorMetadataLength()
    if num_output_tensors != num_output_meta:
      raise ValueError(
          "The number of output tensors ({0}) should match the number of "
          "output tensor metadata ({1})".format(num_output_tensors,
                                                num_output_meta))


class _MetadataPopulatorWithBuffer(MetadataPopulator):
  """Subclass of MetadataPopulator that populates metadata to a model buffer.

  This class is used to populate metadata into an in-memory model buffer. As we
  use the Zip API to concatenate associated files after the tflite model file,
  the populating operation is developed based on a model file. For an in-memory
  model buffer, we create a tempfile to serve the populating operation. This
  class is then used to generate this tempfile, and delete the file when the
  MetadataPopulator object is deleted.
  """

  def __init__(self, model_buf):
    """Constructor for _MetadataPopulatorWithBuffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Raises:
      ValueError: model_buf is empty.
      ValueError: model_buf does not have the expected flatbuffer identifier.
    """
    if not model_buf:
      raise ValueError("model_buf cannot be empty.")

    with tempfile.NamedTemporaryFile() as temp:
      model_file = temp.name

    with _open_file(model_file, "wb") as f:
      f.write(model_buf)

    super().__init__(model_file)

  def __del__(self):
    """Destructor of _MetadataPopulatorWithBuffer.

    Deletes the tempfile.
    """
    if os.path.exists(self._model_file):
      os.remove(self._model_file)


class MetadataDisplayer(object):
  """Displays metadata and associated file info in human-readable format."""

  def __init__(self, model_buffer, metadata_buffer, associated_file_list):
    """Constructor for MetadataDisplayer.

    Args:
      model_buffer: valid buffer of the model file.
      metadata_buffer: valid buffer of the metadata file.
      associated_file_list: list of associated files in the model file.
    """
    _assert_model_buffer_identifier(model_buffer)
    _assert_metadata_buffer_identifier(metadata_buffer)
    self._model_buffer = model_buffer
    self._metadata_buffer = metadata_buffer
    self._associated_file_list = associated_file_list

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataDisplayer object for the model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataDisplayer object.

    Raises:
      IOError: File not found.
      ValueError: The model does not have metadata.
    """
    _assert_file_exist(model_file)
    with _open_file(model_file, "rb") as f:
      return cls.with_model_buffer(f.read())

  @classmethod
  def with_model_buffer(cls, model_buffer):
    """Creates a MetadataDisplayer object for a file buffer.

    Args:
      model_buffer: TensorFlow Lite model buffer in bytearray.

    Returns:
      MetadataDisplayer object.
    """
    if not model_buffer:
      raise ValueError("model_buffer cannot be empty.")
    metadata_buffer = get_metadata_buffer(model_buffer)
    if not metadata_buffer:
      raise ValueError("The model does not have metadata.")
    associated_file_list = cls._parse_packed_associated_file_list(model_buffer)
    return cls(model_buffer, metadata_buffer, associated_file_list)

  def get_associated_file_buffer(self, filename):
    """Gets the specified associated file content in bytearray.

    Args:
      filename: name of the file to be extracted.

    Returns:
      The file content in bytearray.

    Raises:
      ValueError: if the file does not exist in the model.
    """
    if filename not in self._associated_file_list:
      raise ValueError(
          "The file, {}, does not exist in the model.".format(filename))

    with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf:
      return zf.read(filename)

  def get_metadata_buffer(self):
    """Gets the metadata buffer in bytearray out from the model."""
    return copy.deepcopy(self._metadata_buffer)

  def get_metadata_json(self):
    """Converts the metadata into a json string."""
    return convert_to_json(self._metadata_buffer)

  def get_packed_associated_file_list(self):
    """Returns a list of associated files that are packed in the model.

    Returns:
      A name list of associated files.
    """
    return copy.deepcopy(self._associated_file_list)

  @staticmethod
  def _parse_packed_associated_file_list(model_buf):
    """Gets a list of associated files packed to the model file.

    Args:
      model_buf: valid file buffer.

    Returns:
      List of packed associated files.
    """

    try:
      with _open_as_zipfile(io.BytesIO(model_buf)) as zf:
        return zf.namelist()
    except zipfile.BadZipFile:
      return []


# Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util.
def convert_to_json(metadata_buffer):
  """Converts the metadata into a json string.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: an error occurred when parsing the metadata schema file.
  """

  opt = _pywrap_flatbuffers.IDLOptions()
  opt.strict_json = True
  parser = _pywrap_flatbuffers.Parser(opt)
  with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f:
    metadata_schema_content = f.read()
  if not parser.parse(metadata_schema_content):
    raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
  return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)


def _assert_file_exist(filename):
  """Checks if a file exists."""
  if not _exists_file(filename):
    raise IOError("File, '{0}', does not exist.".format(filename))


def _assert_model_file_identifier(model_file):
  """Checks if a model file has the expected TFLite schema identifier."""
  _assert_file_exist(model_file)
  with _open_file(model_file, "rb") as f:
    _assert_model_buffer_identifier(f.read())


def _assert_model_buffer_identifier(model_buf):
  if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
    raise ValueError(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.")


def _assert_metadata_buffer_identifier(metadata_buf):
  """Checks if a metadata buffer has the expected Metadata schema identifier."""
  if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
      metadata_buf, 0):
    raise ValueError(
        "The metadata buffer does not have the expected identifier, and may not"
        " be a valid TFLite Metadata.")


def get_metadata_buffer(model_buf):
  """Returns the metadata in the model file as a buffer.

  Args:
    model_buf: valid buffer of the model file.

  Returns:
    Metadata buffer. Returns `None` if the model does not have metadata.
  """
  tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

  # Gets metadata from the model file.
  for i in range(tflite_model.MetadataLength()):
    meta = tflite_model.Metadata(i)
    if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
      buffer_index = meta.Buffer()
      metadata = tflite_model.Buffers(buffer_index)
      return metadata.DataAsNumpy().tobytes()

  return None
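Taken together, the two classes support a round trip of packing metadata into a model and then inspecting it. A minimal sketch based on the docstrings above (all file paths are placeholders):

```python
# Sketch only: "model.tflite", "model_metadata.bin", and "labels.txt" are
# placeholder paths; the metadata buffer must follow metadata_schema.fbs.
from mediapipe.tasks.python.metadata import metadata

# Pack a metadata FlatBuffer and one associated file into the model file.
populator = metadata.MetadataPopulator.with_model_file("model.tflite")
populator.load_metadata_file("model_metadata.bin")
populator.load_associated_files(["labels.txt"])
populator.populate()

# Read the packed metadata back out as JSON and list the packed files.
displayer = metadata.MetadataDisplayer.with_model_file("model.tflite")
print(displayer.get_metadata_json())
print(displayer.get_packed_associated_file_list())
```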
mediapipe/tasks/python/metadata/metadata_displayer_cli.py (new file)
@@ -0,0 +1,34 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLI tool for displaying metadata."""

from absl import app
from absl import flags

from mediapipe.tasks.python.metadata import metadata

FLAGS = flags.FLAGS
flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.')
flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.')


def main(_):
  displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path)
  with open(FLAGS.export_json_path, 'w') as f:
    f.write(displayer.get_metadata_json())


if __name__ == '__main__':
  app.run(main)
mediapipe/tasks/python/metadata/metadata_parser.py.template (new file)
@@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information about the metadata parser that this python library depends on."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


class MetadataParser(object):
  """Information about the metadata parser."""

  # The version of the metadata parser.
  VERSION = "{LATEST_METADATA_PARSER_VERSION}"
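The stamp_metadata_parser_version rule in the BUILD file above substitutes {LATEST_METADATA_PARSER_VERSION} into this template; the stamped module can then be queried at runtime, which is what metadata_parser_test.py below checks. A minimal sketch:

```python
# Sketch only: reads the stamped parser version from the generated module.
from mediapipe.tasks.python.metadata import metadata_parser

print(metadata_parser.MetadataParser.VERSION)  # An "x.y.z" version string.
```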
mediapipe/tasks/python/test/BUILD (modified)
@@ -19,9 +19,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 licenses(["notice"])
 
 py_library(
-    name = "test_util",
+    name = "test_utils",
     testonly = 1,
-    srcs = ["test_util.py"],
+    srcs = ["test_utils.py"],
     srcs_version = "PY3",
     deps = [
         "//mediapipe/python:_framework_bindings",
mediapipe/tasks/python/test/metadata/BUILD (new file)
@@ -0,0 +1,31 @@
package(
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

py_test(
    name = "metadata_test",
    srcs = ["metadata_test.py"],
    data = ["//mediapipe/tasks/testdata/metadata:data_files"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        "//mediapipe/tasks/metadata:metadata_schema_py",
        "//mediapipe/tasks/metadata:schema_py",
        "//mediapipe/tasks/python/metadata",
        "//mediapipe/tasks/python/test:test_utils",
        "@flatbuffers//:runtime_py",
    ],
)

py_test(
    name = "metadata_parser_test",
    srcs = ["metadata_parser_test.py"],
    python_version = "PY3",
    srcs_version = "PY2AND3",
    deps = [
        "//mediapipe/tasks/python/metadata",
    ],
)
mediapipe/tasks/python/test/metadata/metadata_parser_test.py (new file)
@@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mediapipe.tasks.python.metadata.metadata_parser."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from absl.testing import absltest

from mediapipe.tasks.python.metadata import metadata_parser


class MetadataParserTest(absltest.TestCase):

  def testVersionWellFormedSemanticVersion(self):
    # Validates that the version is well-formed (x.y.z).
    self.assertTrue(
        re.match('[0-9]+\\.[0-9]+\\.[0-9]+',
                 metadata_parser.MetadataParser.VERSION))


if __name__ == '__main__':
  absltest.main()
mediapipe/tasks/python/test/metadata/metadata_test.py (new file)
@@ -0,0 +1,857 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mediapipe.tasks.python.metadata.metadata."""

import enum
import os

from absl.testing import absltest
from absl.testing import parameterized
import six

import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata import metadata as _metadata
from mediapipe.tasks.python.test import test_utils


class Tokenizer(enum.Enum):
  BERT_TOKENIZER = 0
  SENTENCE_PIECE = 1


class TensorType(enum.Enum):
  INPUT = 0
  OUTPUT = 1


def _read_file(file_name, mode="rb"):
  with open(file_name, mode) as f:
    return f.read()


class MetadataTest(parameterized.TestCase):

  def setUp(self):
    super(MetadataTest, self).setUp()
    self._invalid_model_buf = None
    self._invalid_file = "not_existed_file"
    self._model_buf = self._create_model_buf()
    self._model_file = self.create_tempfile().full_path
    with open(self._model_file, "wb") as f:
      f.write(self._model_buf)
    self._metadata_file = self._create_metadata_file()
    self._metadata_file_with_version = self._create_metadata_file_with_version(
        self._metadata_file, "1.0.0")
    self._file1 = self.create_tempfile("file1").full_path
    self._file2 = self.create_tempfile("file2").full_path
    self._file2_content = b"file2_content"
    with open(self._file2, "wb") as f:
      f.write(self._file2_content)
    self._file3 = self.create_tempfile("file3").full_path

  def _create_model_buf(self):
    # Create a model with two inputs and one output, which matches the metadata
    # created by _create_metadata_file().
    metadata_field = _schema_fb.MetadataT()
    subgraph = _schema_fb.SubGraphT()
    subgraph.inputs = [0, 1]
    subgraph.outputs = [2]

    metadata_field.name = "meta"
    buffer_field = _schema_fb.BufferT()
    model = _schema_fb.ModelT()
    model.subgraphs = [subgraph]
    # Creates the metadata and buffer fields for testing purposes.
    model.metadata = [metadata_field, metadata_field]
    model.buffers = [buffer_field, buffer_field, buffer_field]
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(
        model.Pack(model_builder),
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
    return model_builder.Output()

  def _create_metadata_file(self):
    associated_file1 = _metadata_fb.AssociatedFileT()
    associated_file1.name = b"file1"
    associated_file2 = _metadata_fb.AssociatedFileT()
    associated_file2.name = b"file2"
    self.expected_recorded_files = [
        six.ensure_str(associated_file1.name),
        six.ensure_str(associated_file2.name)
    ]

    input_meta = _metadata_fb.TensorMetadataT()
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.associatedFiles = [associated_file2]
    subgraph = _metadata_fb.SubGraphMetadataT()
    # Create a model with two inputs and one output.
    subgraph.inputTensorMetadata = [input_meta, input_meta]
    subgraph.outputTensorMetadata = [output_meta]

    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "Mobilenet_quantized"
    model_meta.associatedFiles = [associated_file1]
    model_meta.subgraphMetadata = [subgraph]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

    metadata_file = self.create_tempfile().full_path
    with open(metadata_file, "wb") as f:
      f.write(b.Output())
    return metadata_file

  def _create_model_buffer_with_wrong_identifier(self):
    wrong_identifier = b"widn"
    model = _schema_fb.ModelT()
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(model.Pack(model_builder), wrong_identifier)
    return model_builder.Output()

  def _create_metadata_buffer_with_wrong_identifier(self):
    # Creates a metadata with wrong identifier
    wrong_identifier = b"widn"
    metadata = _metadata_fb.ModelMetadataT()
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
    return metadata_builder.Output()

  def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
                                         identifier):
    # For testing purposes only. MetadataPopulator cannot populate metadata with
    # wrong identifiers.
    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = metadata_buf
    model.buffers = [buffer_field]
    # Creates a new metadata field.
    metadata_field = _schema_fb.MetadataT()
    metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
    metadata_field.buffer = len(model.buffers) - 1
    model.metadata = [metadata_field]
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), identifier)
    return b.Output()

  def _create_metadata_file_with_version(self, metadata_file, min_version):
    # Creates a new metadata file with the specified min_version for testing
    # purposes.
    metadata_buf = bytearray(_read_file(metadata_file))

    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version

    b = flatbuffers.Builder(0)
    b.Finish(
        metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

    metadata_file_with_version = self.create_tempfile().full_path
    with open(metadata_file_with_version, "wb") as f:
      f.write(b.Output())
    return metadata_file_with_version


class MetadataPopulatorTest(MetadataTest):

  def _create_bert_tokenizer(self):
    vocab_file_name = "bert_vocab"
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = vocab_file_name
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
    tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
    tokenizer.options.vocabFile = [vocab]
    return tokenizer, [vocab_file_name]

  def _create_sentence_piece_tokenizer(self):
    sp_model_name = "sp_model"
    vocab_file_name = "sp_vocab"
    sp_model = _metadata_fb.AssociatedFileT()
    sp_model.name = sp_model_name
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = vocab_file_name
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = (
        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
    tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
    tokenizer.options.sentencePieceModel = [sp_model]
    tokenizer.options.vocabFile = [vocab]
    return tokenizer, [sp_model_name, vocab_file_name]

  def _create_tokenizer(self, tokenizer_type):
    if tokenizer_type is Tokenizer.BERT_TOKENIZER:
      return self._create_bert_tokenizer()
    elif tokenizer_type is Tokenizer.SENTENCE_PIECE:
      return self._create_sentence_piece_tokenizer()
    else:
      raise ValueError(
          "The tokenizer type, {0}, is unsupported.".format(tokenizer_type))

  def _create_tempfiles(self, file_names):
    tempfiles = []
    for name in file_names:
      tempfiles.append(self.create_tempfile(name).full_path)
    return tempfiles

  def _create_model_meta_with_subgraph_meta(self, subgraph_meta):
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.subgraphMetadata = [subgraph_meta]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    return b.Output()

  def testToValidModelFile(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    self.assertIsInstance(populator, _metadata.MetadataPopulator)

  def testToInvalidModelFile(self):
    with self.assertRaises(IOError) as error:
      _metadata.MetadataPopulator.with_model_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testToValidModelBuffer(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    self.assertIsInstance(populator, _metadata.MetadataPopulator)

  def testToInvalidModelBuffer(self):
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
    self.assertEqual("model_buf cannot be empty.", str(error.exception))

  def testToModelBufferWithWrongIdentifier(self):
    model_buf = self._create_model_buffer_with_wrong_identifier()
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataPopulator.with_model_buffer(model_buf)
    self.assertEqual(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.", str(error.exception))

  def testSinglePopulateAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    populator.load_associated_files([self._file1])
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [os.path.basename(self._file1)]
    self.assertEqual(set(packed_files), set(expected_packed_files))

  def testRepeatedPopulateAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_associated_files([self._file1, self._file2])
    # Loads file2 multiple times.
    populator.load_associated_files([self._file2])
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertLen(packed_files, 2)
    self.assertEqual(set(packed_files), set(expected_packed_files))

    # Check if the model buffer read from file is the same as that read from
    # get_model_buffer().
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateInvalidAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(IOError) as error:
      populator.load_associated_files([self._invalid_file])
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testPopulatePackedAssociatedFile(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    populator.load_associated_files([self._file1])
    populator.populate()
    with self.assertRaises(ValueError) as error:
      populator.load_associated_files([self._file1])
      populator.populate()
    self.assertEqual(
        "File, '{0}', has already been packed.".format(
            os.path.basename(self._file1)), str(error.exception))

  def testLoadAssociatedFileBuffers(self):
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    file_buffer = _read_file(self._file1)
    populator.load_associated_file_buffers({self._file1: file_buffer})
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [os.path.basename(self._file1)]
    self.assertEqual(set(packed_files), set(expected_packed_files))

  def testRepeatedLoadAssociatedFileBuffers(self):
    file_buffer1 = _read_file(self._file1)
    file_buffer2 = _read_file(self._file2)
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)

    populator.load_associated_file_buffers({
        self._file1: file_buffer1,
        self._file2: file_buffer2
    })
    # Loads file2 multiple times.
    populator.load_associated_file_buffers({self._file2: file_buffer2})
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertEqual(set(packed_files), set(expected_packed_files))

    # Check if the model buffer read from file is the same as that read from
|
||||
# get_model_buffer().
|
||||
model_buf_from_file = _read_file(self._model_file)
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testLoadPackedAssociatedFileBuffersFails(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
file_buffer = _read_file(self._file1)
|
||||
populator.load_associated_file_buffers({self._file1: file_buffer})
|
||||
populator.populate()
|
||||
|
||||
# Loading file1 again should fail.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_associated_file_buffers({self._file1: file_buffer})
|
||||
populator.populate()
|
||||
self.assertEqual(
|
||||
"File, '{0}', has already been packed.".format(
|
||||
os.path.basename(self._file1)), str(error.exception))
|
||||
|
||||
def testGetPackedAssociatedFileList(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
self.assertEqual(packed_files, [])
|
||||
|
||||
def testPopulateMetadataFileToEmptyModelFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
|
||||
model_buf_from_file = _read_file(self._model_file)
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
|
||||
# self._model_file already has two elements in the metadata field, so the
|
||||
# populated TFLite metadata will be the third element.
|
||||
metadata_field = model.Metadata(2)
|
||||
self.assertEqual(
|
||||
six.ensure_str(metadata_field.Name()),
|
||||
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
|
||||
|
||||
buffer_index = metadata_field.Buffer()
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
expected_metadata_buf = bytearray(
|
||||
_read_file(self._metadata_file_with_version))
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
# Up to now, we've verified the correctness of the model buffer read from the
|
||||
# file. Next, check that get_model_buffer() returns the same model buffer.
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateMetadataFileWithoutAssociatedFiles(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1])
|
||||
# self._file2 should also be loaded, because it is recorded in the metadata.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.populate()
|
||||
self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
|
||||
"not been loaded into the populator.").format(
|
||||
os.path.basename(self._file2)), str(error.exception))
|
||||
|
||||
def testPopulateMetadataBufferWithWrongIdentifier(self):
|
||||
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer(metadata_buf)
|
||||
self.assertEqual(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.", str(error.exception))
|
||||
|
||||
def _assert_golden_metadata(self, model_file):
|
||||
model_buf_from_file = _read_file(model_file)
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
|
||||
# There are two elements in model.Metadata array before the population.
|
||||
# Metadata should be packed to the third element in the array.
|
||||
metadata_field = model.Metadata(2)
|
||||
self.assertEqual(
|
||||
six.ensure_str(metadata_field.Name()),
|
||||
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
|
||||
|
||||
buffer_index = metadata_field.Buffer()
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
expected_metadata_buf = bytearray(
|
||||
_read_file(self._metadata_file_with_version))
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
|
||||
# First, creates a dummy metadata different from self._metadata_file. It
|
||||
# needs to have the same input/output tensor numbers as self._model_file.
|
||||
# Populates it and the associated files into the model.
|
||||
input_meta = _metadata_fb.TensorMetadataT()
|
||||
output_meta = _metadata_fb.TensorMetadataT()
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
# Create a model with two inputs and one output.
|
||||
subgraph.inputTensorMetadata = [input_meta, input_meta]
|
||||
subgraph.outputTensorMetadata = [output_meta]
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
model_meta.Pack(b),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf = b.Output()
|
||||
|
||||
# Populate the metadata.
|
||||
populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator1.load_metadata_buffer(metadata_buf)
|
||||
populator1.load_associated_files([self._file1, self._file2])
|
||||
populator1.populate()
|
||||
|
||||
# Then, populate the metadata again.
|
||||
populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator2.load_metadata_file(self._metadata_file)
|
||||
populator2.populate()
|
||||
|
||||
# Test if the metadata is populated correctly.
|
||||
self._assert_golden_metadata(self._model_file)
|
||||
|
||||
def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
|
||||
# Tests if the metadata is populated correctly.
|
||||
self._assert_golden_metadata(self._model_file)
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
# Up to now, we've verified the correctness of the model buffer read from the
|
||||
# file. Next, check that get_model_buffer() returns the same model buffer.
|
||||
model_buf_from_file = _read_file(self._model_file)
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateInvalidMetadataFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
with self.assertRaises(IOError) as error:
|
||||
populator.load_metadata_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testPopulateInvalidMetadataBuffer(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer([])
|
||||
self.assertEqual("The metadata to be populated is empty.",
|
||||
str(error.exception))
|
||||
|
||||
def testGetModelBufferBeforePopulatingData(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
model_buf = populator.get_model_buffer()
|
||||
expected_model_buf = self._model_buf
|
||||
self.assertEqual(model_buf, expected_model_buf)
|
||||
|
||||
def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self):
|
||||
# Create dummy metadata without any SubgraphMetadata.
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
builder = flatbuffers.Builder(0)
|
||||
builder.Finish(
|
||||
model_meta.Pack(builder),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
meta_buf = builder.Output()
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer(meta_buf)
|
||||
self.assertEqual(
|
||||
"The number of SubgraphMetadata should be exactly one, but got 0.",
|
||||
str(error.exception))
|
||||
|
||||
def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self):
|
||||
# Create a dummy metadata with no input tensor metadata, while the expected
|
||||
# number is 2.
|
||||
output_meta = _metadata_fb.TensorMetadataT()
|
||||
subgraph_meta = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph_meta.outputTensorMetadata = [output_meta]
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.subgraphMetadata = [subgraph_meta]
|
||||
builder = flatbuffers.Builder(0)
|
||||
builder.Finish(
|
||||
model_meta.Pack(builder),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
meta_buf = builder.Output()
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer(meta_buf)
|
||||
self.assertEqual(
|
||||
("The number of input tensors (2) should match the number of "
|
||||
"input tensor metadata (0)"), str(error.exception))
|
||||
|
||||
def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self):
|
||||
# Create a dummy metadata with no output tensor metadata, while the expected
|
||||
# number is 1.
|
||||
input_meta = _metadata_fb.TensorMetadataT()
|
||||
subgraph_meta = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph_meta.inputTensorMetadata = [input_meta, input_meta]
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.subgraphMetadata = [subgraph_meta]
|
||||
builder = flatbuffers.Builder(0)
|
||||
builder.Finish(
|
||||
model_meta.Pack(builder),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
meta_buf = builder.Output()
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer(meta_buf)
|
||||
self.assertEqual(
|
||||
("The number of output tensors (1) should match the number of "
|
||||
"output tensor metadata (0)"), str(error.exception))
|
||||
|
||||
def testLoadMetadataAndAssociatedFilesShouldSucceed(self):
|
||||
# Create a src model with metadata and two associated files.
|
||||
src_model_buf = self._create_model_buf()
|
||||
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
|
||||
populator_src.load_metadata_file(self._metadata_file)
|
||||
populator_src.load_associated_files([self._file1, self._file2])
|
||||
populator_src.populate()
|
||||
|
||||
# Create a model to be populated with the metadata and files from
|
||||
# src_model_buf.
|
||||
dst_model_buf = self._create_model_buf()
|
||||
populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
|
||||
populator_dst.load_metadata_and_associated_files(
|
||||
populator_src.get_model_buffer())
|
||||
populator_dst.populate()
|
||||
|
||||
# Tests if the metadata and associated files are populated correctly.
|
||||
dst_model_file = self.create_tempfile().full_path
|
||||
with open(dst_model_file, "wb") as f:
|
||||
f.write(populator_dst.get_model_buffer())
|
||||
self._assert_golden_metadata(dst_model_file)
|
||||
|
||||
recorded_files = populator_dst.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
"testcase_name": "InputTensorWithBert",
|
||||
"tensor_type": TensorType.INPUT,
|
||||
"tokenizer_type": Tokenizer.BERT_TOKENIZER
|
||||
}, {
|
||||
"testcase_name": "OutputTensorWithBert",
|
||||
"tensor_type": TensorType.OUTPUT,
|
||||
"tokenizer_type": Tokenizer.BERT_TOKENIZER
|
||||
}, {
|
||||
"testcase_name": "InputTensorWithSentencePiece",
|
||||
"tensor_type": TensorType.INPUT,
|
||||
"tokenizer_type": Tokenizer.SENTENCE_PIECE
|
||||
}, {
|
||||
"testcase_name": "OutputTensorWithSentencePiece",
|
||||
"tensor_type": TensorType.OUTPUT,
|
||||
"tokenizer_type": Tokenizer.SENTENCE_PIECE
|
||||
})
|
||||
def testGetRecordedAssociatedFileListWithSubgraphTensor(
|
||||
self, tensor_type, tokenizer_type):
|
||||
# Creates a metadata with the tokenizer in the tensor process units.
|
||||
tokenizer, expected_files = self._create_tokenizer(tokenizer_type)
|
||||
|
||||
# Create the tensor with process units.
|
||||
tensor = _metadata_fb.TensorMetadataT()
|
||||
tensor.processUnits = [tokenizer]
|
||||
|
||||
# Create the subgraph with the tensor.
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
dummy_tensor_meta = _metadata_fb.TensorMetadataT()
|
||||
subgraph.outputTensorMetadata = [dummy_tensor_meta]
|
||||
if tensor_type is TensorType.INPUT:
|
||||
subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta]
|
||||
subgraph.outputTensorMetadata = [dummy_tensor_meta]
|
||||
elif tensor_type is TensorType.OUTPUT:
|
||||
subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
|
||||
subgraph.outputTensorMetadata = [tensor]
|
||||
else:
|
||||
raise ValueError(
|
||||
"The tensor type, {0}, is unsupported.".format(tensor_type))
|
||||
|
||||
# Create a model metadata with the subgraph metadata
|
||||
meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
|
||||
|
||||
# Creates the tempfiles.
|
||||
tempfiles = self._create_tempfiles(expected_files)
|
||||
|
||||
# Creates the MetadataPopulator object.
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_buffer(meta_buffer)
|
||||
populator.load_associated_files(tempfiles)
|
||||
populator.populate()
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(expected_files))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
"testcase_name": "InputTensorWithBert",
|
||||
"tensor_type": TensorType.INPUT,
|
||||
"tokenizer_type": Tokenizer.BERT_TOKENIZER
|
||||
}, {
|
||||
"testcase_name": "OutputTensorWithBert",
|
||||
"tensor_type": TensorType.OUTPUT,
|
||||
"tokenizer_type": Tokenizer.BERT_TOKENIZER
|
||||
}, {
|
||||
"testcase_name": "InputTensorWithSentencePiece",
|
||||
"tensor_type": TensorType.INPUT,
|
||||
"tokenizer_type": Tokenizer.SENTENCE_PIECE
|
||||
}, {
|
||||
"testcase_name": "OutputTensorWithSentencePiece",
|
||||
"tensor_type": TensorType.OUTPUT,
|
||||
"tokenizer_type": Tokenizer.SENTENCE_PIECE
|
||||
})
|
||||
def testGetRecordedAssociatedFileListWithSubgraphProcessUnits(
|
||||
self, tensor_type, tokenizer_type):
|
||||
# Creates a metadata with the tokenizer in the subgraph process units.
|
||||
tokenizer, expected_files = self._create_tokenizer(tokenizer_type)
|
||||
|
||||
# Create the subgraph with process units.
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
if tensor_type is TensorType.INPUT:
|
||||
subgraph.inputProcessUnits = [tokenizer]
|
||||
elif tensor_type is TensorType.OUTPUT:
|
||||
subgraph.outputProcessUnits = [tokenizer]
|
||||
else:
|
||||
raise ValueError(
|
||||
"The tensor type, {0}, is unsupported.".format(tensor_type))
|
||||
|
||||
# Creates the input and output tensor meta to match self._model_file.
|
||||
dummy_tensor_meta = _metadata_fb.TensorMetadataT()
|
||||
subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
|
||||
subgraph.outputTensorMetadata = [dummy_tensor_meta]
|
||||
|
||||
# Create a model metadata with the subgraph metadata
|
||||
meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
|
||||
|
||||
# Creates the tempfiles.
|
||||
tempfiles = self._create_tempfiles(expected_files)
|
||||
|
||||
# Creates the MetadataPopulator object.
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_buffer(meta_buffer)
|
||||
populator.load_associated_files(tempfiles)
|
||||
populator.populate()
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(expected_files))
|
||||
|
||||
def testPopulatedFullPathAssociatedFileShouldSucceed(self):
|
||||
# Create AssociatedFileT using the full path file name.
|
||||
associated_file = _metadata_fb.AssociatedFileT()
|
||||
associated_file.name = self._file1
|
||||
|
||||
# Create model metadata with the associated file.
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.associatedFiles = [associated_file]
|
||||
# Creates the input and output tensor metadata to match self._model_file.
|
||||
dummy_tensor = _metadata_fb.TensorMetadataT()
|
||||
subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor]
|
||||
subgraph.outputTensorMetadata = [dummy_tensor]
|
||||
md_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
|
||||
|
||||
# Populate the metadata to a model.
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_buffer(md_buffer)
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
|
||||
# The recorded file name in metadata should only contain file basename; file
|
||||
# directory should not be included.
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)]))
|
||||
|
||||
|
||||
class MetadataDisplayerTest(MetadataTest):
|
||||
|
||||
def setUp(self):
|
||||
super(MetadataDisplayerTest, self).setUp()
|
||||
self._model_with_meta_file = (
|
||||
self._create_model_with_metadata_and_associated_files())
|
||||
|
||||
def _create_model_with_metadata_and_associated_files(self):
|
||||
model_buf = self._create_model_buf()
|
||||
model_file = self.create_tempfile().full_path
|
||||
with open(model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_file(model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
return model_file
|
||||
|
||||
def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||
model_buf = self._populate_metadata_with_identifier(
|
||||
model_buf, metadata_buf,
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.", str(error.exception))
|
||||
|
||||
def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
metadata_file = self._create_metadata_file()
|
||||
wrong_identifier = b"widn"
|
||||
metadata_buf = bytearray(_read_file(metadata_file))
|
||||
model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
|
||||
wrong_identifier)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.", str(error.exception))
|
||||
|
||||
def testLoadModelFileInvalidModelFileThrowsException(self):
|
||||
with self.assertRaises(IOError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testLoadModelFileModelWithoutMetadataThrowsException(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
self.assertEqual("The model does not have metadata.", str(error.exception))
|
||||
|
||||
def testLoadModelFileModelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(
|
||||
self._model_with_meta_file)
|
||||
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
|
||||
|
||||
def testLoadModelBufferInvalidModelBufferThrowsException(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1))
|
||||
self.assertEqual("model_buffer cannot be empty.", str(error.exception))
|
||||
|
||||
def testLoadModelBufferModelWithoutMetadataThrowsException(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf())
|
||||
self.assertEqual("The model does not have metadata.", str(error.exception))
|
||||
|
||||
def testLoadModelBufferModelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_buffer(
|
||||
_read_file(self._model_with_meta_file))
|
||||
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
|
||||
|
||||
def testGetAssociatedFileBufferShouldSucceed(self):
|
||||
# _model_with_meta_file contains file1 and file2.
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(
|
||||
self._model_with_meta_file)
|
||||
|
||||
actual_content = displayer.get_associated_file_buffer("file2")
|
||||
self.assertEqual(actual_content, self._file2_content)
|
||||
|
||||
def testGetAssociatedFileBufferFailsWithNonExistentFile(self):
|
||||
# _model_with_meta_file contains file1 and file2.
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(
|
||||
self._model_with_meta_file)
|
||||
|
||||
non_existent_file = "non_existent_file"
|
||||
with self.assertRaises(ValueError) as error:
|
||||
displayer.get_associated_file_buffer(non_existent_file)
|
||||
self.assertEqual(
|
||||
"The file, {}, does not exist in the model.".format(non_existent_file),
|
||||
str(error.exception))
|
||||
|
||||
def testGetMetadataBufferShouldSucceed(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(
|
||||
self._model_with_meta_file)
|
||||
actual_buffer = displayer.get_metadata_buffer()
|
||||
actual_json = _metadata.convert_to_json(actual_buffer)
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
||||
with open(golden_json_file_path, "r") as f:
|
||||
expected = f.read()
|
||||
self.assertEqual(actual_json, expected)
|
||||
|
||||
def testGetMetadataJsonModelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(
|
||||
self._model_with_meta_file)
|
||||
actual = displayer.get_metadata_json()
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
||||
expected = _read_file(golden_json_file_path, "r")
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def testGetPackedAssociatedFileListModelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(
|
||||
self._model_with_meta_file)
|
||||
packed_files = displayer.get_packed_associated_file_list()
|
||||
|
||||
expected_packed_files = [
|
||||
os.path.basename(self._file1),
|
||||
os.path.basename(self._file2)
|
||||
]
|
||||
self.assertLen(
|
||||
packed_files, 2,
|
||||
"The following two associated files packed to the model: {0}; {1}"
|
||||
.format(expected_packed_files[0], expected_packed_files[1]))
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
|
||||
class MetadataUtilTest(MetadataTest):
|
||||
|
||||
def test_convert_to_json_should_succeed(self):
|
||||
metadata_buf = _read_file(self._metadata_file_with_version)
|
||||
metadata_json = _metadata.convert_to_json(metadata_buf)
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
||||
expected = _read_file(golden_json_file_path, "r")
|
||||
self.assertEqual(metadata_json, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
45
mediapipe/tasks/python/test/test_utils.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test util for MediaPipe Tasks."""
|
||||
|
||||
import os
|
||||
|
||||
from absl import flags
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
_Image = image_module.Image
|
||||
_ImageFormat = image_frame_module.ImageFormat
|
||||
_RGB_CHANNELS = 3
|
||||
|
||||
|
||||
def test_srcdir():
|
||||
"""Returns the path where to look for test data files."""
|
||||
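# Prefer the test_srcdir flag when it is defined; otherwise fall back to the
# TEST_SRCDIR environment variable set by the test runner.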
if "test_srcdir" in flags.FLAGS:
|
||||
return flags.FLAGS["test_srcdir"].value
|
||||
elif "TEST_SRCDIR" in os.environ:
|
||||
return os.environ["TEST_SRCDIR"]
|
||||
else:
|
||||
raise RuntimeError("Missing TEST_SRCDIR environment variable.")
|
||||
|
||||
|
||||
def get_test_data_path(file_or_dirname: str) -> str:
|
||||
"""Returns full test data path."""
|
||||
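# Walk the test source tree and return the first entry whose name ends with
# file_or_dirname.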
for (directory, subdirs, files) in os.walk(test_srcdir()):
|
||||
for f in subdirs + files:
|
||||
if f.endswith(file_or_dirname):
|
||||
return os.path.join(directory, f)
|
||||
raise ValueError("No %s in test directory" % file_or_dirname)
|
|
@ -31,7 +31,7 @@ py_test(
|
|||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:detections",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_util",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:object_detector",
|
||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||
],
|
||||
|
|
|
@ -25,7 +25,7 @@ from mediapipe.tasks.python.components.containers import bounding_box as boundin
|
|||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import detections as detections_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_util
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import object_detector
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
|
@ -99,8 +99,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_util.get_test_data_path(_IMAGE_FILE))
|
||||
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
|
||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
|
|
12
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -28,9 +28,13 @@ mediapipe_files(srcs = [
|
|||
"mobile_ica_8bit-without-model-metadata.tflite",
|
||||
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||
"mobilenet_v2_1.0_224_quant.tflite",
|
||||
])
|
||||
|
||||
exports_files(["external_file"])
|
||||
exports_files([
|
||||
"external_file",
|
||||
"golden_json.json",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
name = "model_files",
|
||||
|
@ -40,10 +44,14 @@ filegroup(
|
|||
"mobile_ica_8bit-without-model-metadata.tflite",
|
||||
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
|
||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||
"mobilenet_v2_1.0_224_quant.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "data_files",
|
||||
srcs = ["external_file"],
|
||||
srcs = [
|
||||
"external_file",
|
||||
"golden_json.json",
|
||||
],
|
||||
)
|
||||
|
|
28
mediapipe/tasks/testdata/metadata/golden_json.json
vendored
Normal file
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"name": "Mobilenet_quantized",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
},
|
||||
{
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "file2"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "file1"
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
12
third_party/external_files.bzl
vendored
|
@ -166,6 +166,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_golden_json_json",
|
||||
sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/golden_json.json?generation=1664340169675228"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_hair_segmentation_tflite",
|
||||
sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633",
|
||||
|
@ -316,6 +322,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite",
|
||||
sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
|
||||
sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",
|
||||
|
|