Migrate base metadata functionality like MetadataPopulator and MetadataDisplayer class into MediaPipe.
PiperOrigin-RevId: 478279747
This commit is contained in:
		
							parent
							
								
									9568de0570
								
							
						
					
					
						commit
						13f6e0c797
					
				| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
 | 
					load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
package(
 | 
					package(
 | 
				
			||||||
    default_visibility = [
 | 
					    default_visibility = [
 | 
				
			||||||
| 
						 | 
					@ -14,3 +14,13 @@ flatbuffer_cc_library(
 | 
				
			||||||
    name = "metadata_schema_cc",
 | 
					    name = "metadata_schema_cc",
 | 
				
			||||||
    srcs = ["metadata_schema.fbs"],
 | 
					    srcs = ["metadata_schema.fbs"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					flatbuffer_py_library(
 | 
				
			||||||
 | 
					    name = "schema_py",
 | 
				
			||||||
 | 
					    srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					flatbuffer_py_library(
 | 
				
			||||||
 | 
					    name = "metadata_schema_py",
 | 
				
			||||||
 | 
					    srcs = ["metadata_schema.fbs"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										40
									
								
								mediapipe/tasks/python/metadata/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								mediapipe/tasks/python/metadata/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,40 @@
 | 
				
			||||||
 | 
					load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(
 | 
				
			||||||
 | 
					    licenses = ["notice"],  # Apache 2.0
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					stamp_metadata_parser_version(
 | 
				
			||||||
 | 
					    name = "metadata_parser_py",
 | 
				
			||||||
 | 
					    srcs = ["metadata_parser.py.template"],
 | 
				
			||||||
 | 
					    outs = ["metadata_parser.py"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "metadata",
 | 
				
			||||||
 | 
					    srcs = [
 | 
				
			||||||
 | 
					        "metadata.py",
 | 
				
			||||||
 | 
					        ":metadata_parser_py",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"],
 | 
				
			||||||
 | 
					    srcs_version = "PY3",
 | 
				
			||||||
 | 
					    visibility = ["//visibility:public"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/metadata:metadata_schema_py",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/metadata:schema_py",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers",
 | 
				
			||||||
 | 
					        "@flatbuffers//:runtime_py",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_binary(
 | 
				
			||||||
 | 
					    name = "metadata_displayer_cli",
 | 
				
			||||||
 | 
					    srcs = ["metadata_displayer_cli.py"],
 | 
				
			||||||
 | 
					    visibility = [
 | 
				
			||||||
 | 
					        "//visibility:public",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":metadata",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
										13
									
								
								mediapipe/tasks/python/metadata/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								mediapipe/tasks/python/metadata/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,13 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
							
								
								
									
										20
									
								
								mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,20 @@
 | 
				
			||||||
 | 
					load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(
 | 
				
			||||||
 | 
					    default_visibility = ["//mediapipe/tasks:internal"],
 | 
				
			||||||
 | 
					    licenses = ["notice"],  # Apache 2.0
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pybind_extension(
 | 
				
			||||||
 | 
					    name = "_pywrap_flatbuffers",
 | 
				
			||||||
 | 
					    srcs = [
 | 
				
			||||||
 | 
					        "flatbuffers_lib.cc",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    features = ["-use_header_modules"],
 | 
				
			||||||
 | 
					    module_name = "_pywrap_flatbuffers",
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "@flatbuffers",
 | 
				
			||||||
 | 
					        "@local_config_python//:python_headers",
 | 
				
			||||||
 | 
					        "@pybind11",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,59 @@
 | 
				
			||||||
 | 
					/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "flatbuffers/flatbuffers.h"
 | 
				
			||||||
 | 
					#include "flatbuffers/idl.h"
 | 
				
			||||||
 | 
					#include "pybind11/pybind11.h"
 | 
				
			||||||
 | 
					#include "pybind11/pytypes.h"
 | 
				
			||||||
 | 
					#include "pybind11/stl.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace tflite {
 | 
				
			||||||
 | 
					namespace support {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PYBIND11_MODULE(_pywrap_flatbuffers, m) {
 | 
				
			||||||
 | 
					  pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
 | 
				
			||||||
 | 
					      .def(pybind11::init<>())
 | 
				
			||||||
 | 
					      .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
 | 
				
			||||||
 | 
					  pybind11::class_<flatbuffers::Parser>(m, "Parser")
 | 
				
			||||||
 | 
					      .def(pybind11::init<const flatbuffers::IDLOptions&>())
 | 
				
			||||||
 | 
					      .def("parse",
 | 
				
			||||||
 | 
					           [](flatbuffers::Parser* self, const std::string& source) {
 | 
				
			||||||
 | 
					             return self->Parse(source.c_str());
 | 
				
			||||||
 | 
					           })
 | 
				
			||||||
 | 
					      .def_readonly("builder", &flatbuffers::Parser::builder_)
 | 
				
			||||||
 | 
					      .def_readonly("error", &flatbuffers::Parser::error_);
 | 
				
			||||||
 | 
					  pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
 | 
				
			||||||
 | 
					      .def("clear", &flatbuffers::FlatBufferBuilder::Clear)
 | 
				
			||||||
 | 
					      .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
 | 
				
			||||||
 | 
					                                  const std::string& contents) {
 | 
				
			||||||
 | 
					        self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
 | 
				
			||||||
 | 
					                             contents.length());
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					  m.def("generate_text_file", &flatbuffers::GenerateTextFile);
 | 
				
			||||||
 | 
					  m.def(
 | 
				
			||||||
 | 
					      "generate_text",
 | 
				
			||||||
 | 
					      [](const flatbuffers::Parser& parser,
 | 
				
			||||||
 | 
					         const std::string& buffer) -> std::string {
 | 
				
			||||||
 | 
					        std::string text;
 | 
				
			||||||
 | 
					        if (!flatbuffers::GenerateText(
 | 
				
			||||||
 | 
					                parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
 | 
				
			||||||
 | 
					          return "";
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return text;
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace support
 | 
				
			||||||
 | 
					}  // namespace tflite
 | 
				
			||||||
							
								
								
									
										865
									
								
								mediapipe/tasks/python/metadata/metadata.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										865
									
								
								mediapipe/tasks/python/metadata/metadata.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,865 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					"""TensorFlow Lite metadata tools."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					import inspect
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import shutil
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					import tempfile
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					import zipfile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import flatbuffers
 | 
				
			||||||
 | 
					from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version
 | 
				
			||||||
 | 
					from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
 | 
				
			||||||
 | 
					from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					  # If exists, optionally use TensorFlow to open and check files. Used to
 | 
				
			||||||
 | 
					  # support more than local file systems.
 | 
				
			||||||
 | 
					  # In pip requirements, we doesn't necessarily need tensorflow as a dep.
 | 
				
			||||||
 | 
					  import tensorflow as tf
 | 
				
			||||||
 | 
					  _open_file = tf.io.gfile.GFile
 | 
				
			||||||
 | 
					  _exists_file = tf.io.gfile.exists
 | 
				
			||||||
 | 
					except ImportError as e:
 | 
				
			||||||
 | 
					  # If TensorFlow package doesn't exist, fall back to original open and exists.
 | 
				
			||||||
 | 
					  _open_file = open
 | 
				
			||||||
 | 
					  _exists_file = os.path.exists
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _maybe_open_as_binary(filename, mode):
 | 
				
			||||||
 | 
					  """Maybe open the binary file, and returns a file-like."""
 | 
				
			||||||
 | 
					  if hasattr(filename, "read"):  # A file-like has read().
 | 
				
			||||||
 | 
					    return filename
 | 
				
			||||||
 | 
					  openmode = mode if "b" in mode else mode + "b"  # Add binary explicitly.
 | 
				
			||||||
 | 
					  return _open_file(filename, openmode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _open_as_zipfile(filename, mode="r"):
 | 
				
			||||||
 | 
					  """Open file as a zipfile.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    filename: str or file-like or path-like, to the zipfile.
 | 
				
			||||||
 | 
					    mode: str, common file mode for zip.
 | 
				
			||||||
 | 
					          (See: https://docs.python.org/3/library/zipfile.html)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Returns:
 | 
				
			||||||
 | 
					    A ZipFile object.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  file_like = _maybe_open_as_binary(filename, mode)
 | 
				
			||||||
 | 
					  return zipfile.ZipFile(file_like, mode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _is_zipfile(filename):
 | 
				
			||||||
 | 
					  """Checks whether it is a zipfile."""
 | 
				
			||||||
 | 
					  with _maybe_open_as_binary(filename, "r") as f:
 | 
				
			||||||
 | 
					    return zipfile.is_zipfile(f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_path_to_datafile(path):
 | 
				
			||||||
 | 
					  """Gets the path to the specified file in the data dependencies.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  The path is relative to the file calling the function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  It's a simple replacement of
 | 
				
			||||||
 | 
					  "tensorflow.python.platform.resource_loader.get_path_to_datafile".
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    path: a string resource path relative to the calling file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Returns:
 | 
				
			||||||
 | 
					    The path to the specified file present in the data attribute of py_test
 | 
				
			||||||
 | 
					    or py_binary.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1)))  # pylint: disable=protected-access
 | 
				
			||||||
 | 
					  return os.path.join(data_files_path, path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile(
 | 
				
			||||||
 | 
					    "../../metadata/metadata_schema.fbs")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: add delete method for associated files.
 | 
				
			||||||
 | 
					class MetadataPopulator(object):
 | 
				
			||||||
 | 
					  """Packs metadata and associated files into TensorFlow Lite model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MetadataPopulator can be used to populate metadata and model associated files
 | 
				
			||||||
 | 
					  into a model file or a model buffer (in bytearray). It can also help to
 | 
				
			||||||
 | 
					  inspect list of files that have been packed into the model or are supposed to
 | 
				
			||||||
 | 
					  be packed into the model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  The metadata file (or buffer) should be generated based on the metadata
 | 
				
			||||||
 | 
					  schema:
 | 
				
			||||||
 | 
					  third_party/tensorflow/lite/schema/metadata_schema.fbs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Example usage:
 | 
				
			||||||
 | 
					  Populate matadata and label file into an image classifier model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  First, based on metadata_schema.fbs, generate the metadata for this image
 | 
				
			||||||
 | 
					  classifer model using Flatbuffers API. Attach the label file onto the ouput
 | 
				
			||||||
 | 
					  tensor (the tensor of probabilities) in the metadata.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Then, pack the metadata and label file into the model as follows.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ```python
 | 
				
			||||||
 | 
					    # Populating a metadata file (or a metadta buffer) and associated files to
 | 
				
			||||||
 | 
					    a model file:
 | 
				
			||||||
 | 
					    populator = MetadataPopulator.with_model_file(model_file)
 | 
				
			||||||
 | 
					    # For metadata buffer (bytearray read from the metadata file), use:
 | 
				
			||||||
 | 
					    # populator.load_metadata_buffer(metadata_buf)
 | 
				
			||||||
 | 
					    populator.load_metadata_file(metadata_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([label.txt])
 | 
				
			||||||
 | 
					    # For associated file buffer (bytearray read from the file), use:
 | 
				
			||||||
 | 
					    # populator.load_associated_file_buffers({"label.txt": b"file content"})
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Populating a metadata file (or a metadata buffer) and associated files to
 | 
				
			||||||
 | 
					    a model buffer:
 | 
				
			||||||
 | 
					    populator = MetadataPopulator.with_model_buffer(model_buf)
 | 
				
			||||||
 | 
					    populator.load_metadata_file(metadata_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([label.txt])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					    # Writing the updated model buffer into a file.
 | 
				
			||||||
 | 
					    updated_model_buf = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    with open("updated_model.tflite", "wb") as f:
 | 
				
			||||||
 | 
					      f.write(updated_model_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Transferring metadata and associated files from another TFLite model:
 | 
				
			||||||
 | 
					    populator = MetadataPopulator.with_model_buffer(model_buf)
 | 
				
			||||||
 | 
					    populator_dst.load_metadata_and_associated_files(src_model_buf)
 | 
				
			||||||
 | 
					    populator_dst.populate()
 | 
				
			||||||
 | 
					    updated_model_buf = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    with open("updated_model.tflite", "wb") as f:
 | 
				
			||||||
 | 
					      f.write(updated_model_buf)
 | 
				
			||||||
 | 
					    ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Note that existing metadata buffer (if applied) will be overridden by the new
 | 
				
			||||||
 | 
					  metadata buffer.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  # As Zip API is used to concatenate associated files after tflite model file,
 | 
				
			||||||
 | 
					  # the populating operation is developed based on a model file. For in-memory
 | 
				
			||||||
 | 
					  # model buffer, we create a tempfile to serve the populating operation.
 | 
				
			||||||
 | 
					  # Creating the deleting such a tempfile is handled by the class,
 | 
				
			||||||
 | 
					  # _MetadataPopulatorWithBuffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  METADATA_FIELD_NAME = "TFLITE_METADATA"
 | 
				
			||||||
 | 
					  TFLITE_FILE_IDENTIFIER = b"TFL3"
 | 
				
			||||||
 | 
					  METADATA_FILE_IDENTIFIER = b"M001"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self, model_file):
 | 
				
			||||||
 | 
					    """Constructor for MetadataPopulator.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_file: valid path to a TensorFlow Lite model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      IOError: File not found.
 | 
				
			||||||
 | 
					      ValueError: the model does not have the expected flatbuffer identifer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    _assert_model_file_identifier(model_file)
 | 
				
			||||||
 | 
					    self._model_file = model_file
 | 
				
			||||||
 | 
					    self._metadata_buf = None
 | 
				
			||||||
 | 
					    # _associated_files is a dict of file name and file buffer.
 | 
				
			||||||
 | 
					    self._associated_files = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def with_model_file(cls, model_file):
 | 
				
			||||||
 | 
					    """Creates a MetadataPopulator object that populates data to a model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_file: valid path to a TensorFlow Lite model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      MetadataPopulator object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      IOError: File not found.
 | 
				
			||||||
 | 
					      ValueError: the model does not have the expected flatbuffer identifer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return cls(model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # TODO: investigate if type check can be applied to model_buf for
 | 
				
			||||||
 | 
					  # FB.
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def with_model_buffer(cls, model_buf):
 | 
				
			||||||
 | 
					    """Creates a MetadataPopulator object that populates data to a model buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_buf: TensorFlow Lite model buffer in bytearray.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: the model does not have the expected flatbuffer identifer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return _MetadataPopulatorWithBuffer(model_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_model_buffer(self):
 | 
				
			||||||
 | 
					    """Gets the buffer of the model with packed metadata and associated files.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      Model buffer (in bytearray).
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    with _open_file(self._model_file, "rb") as f:
 | 
				
			||||||
 | 
					      return f.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_packed_associated_file_list(self):
 | 
				
			||||||
 | 
					    """Gets a list of associated files packed to the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      List of packed associated files.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not _is_zipfile(self._model_file):
 | 
				
			||||||
 | 
					      return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with _open_as_zipfile(self._model_file, "r") as zf:
 | 
				
			||||||
 | 
					      return zf.namelist()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_recorded_associated_file_list(self):
 | 
				
			||||||
 | 
					    """Gets a list of associated files recorded in metadata of the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Associated files may be attached to a model, a subgraph, or an input/output
 | 
				
			||||||
 | 
					    tensor.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      List of recorded associated files.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not self._metadata_buf:
 | 
				
			||||||
 | 
					      return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
 | 
				
			||||||
 | 
					        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
 | 
				
			||||||
 | 
					            self._metadata_buf, 0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return [
 | 
				
			||||||
 | 
					        file.name.decode("utf-8")
 | 
				
			||||||
 | 
					        for file in self._get_recorded_associated_file_object_list(metadata)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def load_associated_file_buffers(self, associated_files):
 | 
				
			||||||
 | 
					    """Loads the associated file buffers (in bytearray) to be populated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      associated_files: a dictionary of associated file names and corresponding
 | 
				
			||||||
 | 
					        file buffers, such as {"file.txt": b"file content"}. If pass in file
 | 
				
			||||||
 | 
					          paths for the file name, only the basename will be populated.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._associated_files.update({
 | 
				
			||||||
 | 
					        os.path.basename(name): buffers
 | 
				
			||||||
 | 
					        for name, buffers in associated_files.items()
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def load_associated_files(self, associated_files):
 | 
				
			||||||
 | 
					    """Loads associated files that to be concatenated after the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      associated_files: list of file paths.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      IOError:
 | 
				
			||||||
 | 
					        File not found.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    for af_name in associated_files:
 | 
				
			||||||
 | 
					      _assert_file_exist(af_name)
 | 
				
			||||||
 | 
					      with _open_file(af_name, "rb") as af:
 | 
				
			||||||
 | 
					        self.load_associated_file_buffers({af_name: af.read()})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def load_metadata_buffer(self, metadata_buf):
 | 
				
			||||||
 | 
					    """Loads the metadata buffer (in bytearray) to be populated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      metadata_buf: metadata buffer (in bytearray) to be populated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: The metadata to be populated is empty.
 | 
				
			||||||
 | 
					      ValueError: The metadata does not have the expected flatbuffer identifer.
 | 
				
			||||||
 | 
					      ValueError: Cannot get minimum metadata parser version.
 | 
				
			||||||
 | 
					      ValueError: The number of SubgraphMetadata is not 1.
 | 
				
			||||||
 | 
					      ValueError: The number of input/output tensors does not match the number
 | 
				
			||||||
 | 
					        of input/output tensor metadata.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not metadata_buf:
 | 
				
			||||||
 | 
					      raise ValueError("The metadata to be populated is empty.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._validate_metadata(metadata_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Gets the minimum metadata parser version of the metadata_buf.
 | 
				
			||||||
 | 
					    min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
 | 
				
			||||||
 | 
					        bytes(metadata_buf))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Inserts in the minimum metadata parser version into the metadata_buf.
 | 
				
			||||||
 | 
					    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
 | 
				
			||||||
 | 
					        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
 | 
				
			||||||
 | 
					    metadata.minParserVersion = min_version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Remove local file directory in the `name` field of `AssociatedFileT`, and
 | 
				
			||||||
 | 
					    # make it consistent with the name of the actual file packed in the model.
 | 
				
			||||||
 | 
					    self._use_basename_for_associated_files_in_metadata(metadata)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    metadata_buf_with_version = b.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._metadata_buf = metadata_buf_with_version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def load_metadata_file(self, metadata_file):
 | 
				
			||||||
 | 
					    """Loads the metadata file to be populated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      metadata_file: path to the metadata file to be populated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      IOError: File not found.
 | 
				
			||||||
 | 
					      ValueError: The metadata to be populated is empty.
 | 
				
			||||||
 | 
					      ValueError: The metadata does not have the expected flatbuffer identifer.
 | 
				
			||||||
 | 
					      ValueError: Cannot get minimum metadata parser version.
 | 
				
			||||||
 | 
					      ValueError: The number of SubgraphMetadata is not 1.
 | 
				
			||||||
 | 
					      ValueError: The number of input/output tensors does not match the number
 | 
				
			||||||
 | 
					        of input/output tensor metadata.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    _assert_file_exist(metadata_file)
 | 
				
			||||||
 | 
					    with _open_file(metadata_file, "rb") as f:
 | 
				
			||||||
 | 
					      metadata_buf = f.read()
 | 
				
			||||||
 | 
					    self.load_metadata_buffer(bytearray(metadata_buf))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def load_metadata_and_associated_files(self, src_model_buf):
 | 
				
			||||||
 | 
					    """Loads the metadata and associated files from another model buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      src_model_buf: source model buffer (in bytearray) with metadata and
 | 
				
			||||||
 | 
					        associated files.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Load the model metadata from src_model_buf if exist.
 | 
				
			||||||
 | 
					    metadata_buffer = get_metadata_buffer(src_model_buf)
 | 
				
			||||||
 | 
					    if metadata_buffer:
 | 
				
			||||||
 | 
					      self.load_metadata_buffer(metadata_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load the associated files from src_model_buf if exist.
 | 
				
			||||||
 | 
					    if _is_zipfile(io.BytesIO(src_model_buf)):
 | 
				
			||||||
 | 
					      with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
 | 
				
			||||||
 | 
					        self.load_associated_file_buffers(
 | 
				
			||||||
 | 
					            {f: zf.read(f) for f in zf.namelist()})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def populate(self):
 | 
				
			||||||
 | 
					    """Populates loaded metadata and associated files into the model file."""
 | 
				
			||||||
 | 
					    self._assert_validate()
 | 
				
			||||||
 | 
					    self._populate_metadata_buffer()
 | 
				
			||||||
 | 
					    self._populate_associated_files()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _assert_validate(self):
 | 
				
			||||||
 | 
					    """Validates the metadata and associated files to be populated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError:
 | 
				
			||||||
 | 
					        File is recorded in the metadata, but is not going to be populated.
 | 
				
			||||||
 | 
					        File has already been packed.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Gets files that are recorded in metadata.
 | 
				
			||||||
 | 
					    recorded_files = self.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Gets files that have been packed to self._model_file.
 | 
				
			||||||
 | 
					    packed_files = self.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Gets the file name of those associated files to be populated.
 | 
				
			||||||
 | 
					    to_be_populated_files = self._associated_files.keys()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Checks all files recorded in the metadata will be populated.
 | 
				
			||||||
 | 
					    for rf in recorded_files:
 | 
				
			||||||
 | 
					      if rf not in to_be_populated_files and rf not in packed_files:
 | 
				
			||||||
 | 
					        raise ValueError("File, '{0}', is recorded in the metadata, but has "
 | 
				
			||||||
 | 
					                         "not been loaded into the populator.".format(rf))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for f in to_be_populated_files:
 | 
				
			||||||
 | 
					      if f in packed_files:
 | 
				
			||||||
 | 
					        raise ValueError("File, '{0}', has already been packed.".format(f))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if f not in recorded_files:
 | 
				
			||||||
 | 
					        warnings.warn(
 | 
				
			||||||
 | 
					            "File, '{0}', does not exist in the metadata. But packing it to "
 | 
				
			||||||
 | 
					            "tflite model is still allowed.".format(f))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _copy_archived_files(self, src_zip, file_list, dst_zip):
 | 
				
			||||||
 | 
					    """Copy archieved files in file_list from src_zip ro dst_zip."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not _is_zipfile(src_zip):
 | 
				
			||||||
 | 
					      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with _open_as_zipfile(src_zip, "r") as src_zf, \
 | 
				
			||||||
 | 
					         _open_as_zipfile(dst_zip, "a") as dst_zf:
 | 
				
			||||||
 | 
					      src_list = src_zf.namelist()
 | 
				
			||||||
 | 
					      for f in file_list:
 | 
				
			||||||
 | 
					        if f not in src_list:
 | 
				
			||||||
 | 
					          raise ValueError(
 | 
				
			||||||
 | 
					              "File, '{0}', does not exist in the zipfile, {1}.".format(
 | 
				
			||||||
 | 
					                  f, src_zip))
 | 
				
			||||||
 | 
					        file_buffer = src_zf.read(f)
 | 
				
			||||||
 | 
					        dst_zf.writestr(f, file_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _get_associated_files_from_process_units(self, table, field_name):
 | 
				
			||||||
 | 
					    """Gets the files that are attached the process units field of a table.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      table: a Flatbuffers table object that contains fields of an array of
 | 
				
			||||||
 | 
					        ProcessUnit, such as TensorMetadata and SubGraphMetadata.
 | 
				
			||||||
 | 
					      field_name: the name of the field in the table that represents an array of
 | 
				
			||||||
 | 
					        ProcessUnit. If the table is TensorMetadata, field_name can be
 | 
				
			||||||
 | 
					        "ProcessUnits". If the table is SubGraphMetadata, field_name can be
 | 
				
			||||||
 | 
					        either "InputProcessUnits" or "OutputProcessUnits".
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A list of AssociatedFileT objects.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if table is None:
 | 
				
			||||||
 | 
					      return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    file_list = []
 | 
				
			||||||
 | 
					    process_units = getattr(table, field_name)
 | 
				
			||||||
 | 
					    # If the process_units field is not populated, it will be None. Use an
 | 
				
			||||||
 | 
					    # empty list to skip the check.
 | 
				
			||||||
 | 
					    for process_unit in process_units or []:
 | 
				
			||||||
 | 
					      options = process_unit.options
 | 
				
			||||||
 | 
					      if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
 | 
				
			||||||
 | 
					                              _metadata_fb.RegexTokenizerOptionsT)):
 | 
				
			||||||
 | 
					        file_list += self._get_associated_files_from_table(options, "vocabFile")
 | 
				
			||||||
 | 
					      elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
 | 
				
			||||||
 | 
					        file_list += self._get_associated_files_from_table(
 | 
				
			||||||
 | 
					            options, "sentencePieceModel")
 | 
				
			||||||
 | 
					        file_list += self._get_associated_files_from_table(options, "vocabFile")
 | 
				
			||||||
 | 
					    return file_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _get_associated_files_from_table(self, table, field_name):
 | 
				
			||||||
 | 
					    """Gets the associated files that are attached a table directly.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      table: a Flatbuffers table object that contains fields of an array of
 | 
				
			||||||
 | 
					        AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
 | 
				
			||||||
 | 
					      field_name: the name of the field in the table that represents an array of
 | 
				
			||||||
 | 
					        ProcessUnit. If the table is TensorMetadata, field_name can be
 | 
				
			||||||
 | 
					        "AssociatedFiles". If the table is BertTokenizerOptions, field_name can
 | 
				
			||||||
 | 
					        be "VocabFile".
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A list of AssociatedFileT objects.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if table is None:
 | 
				
			||||||
 | 
					      return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # If the associated file field is not populated,
 | 
				
			||||||
 | 
					    # `getattr(table, field_name)` will be None. Return an empty list.
 | 
				
			||||||
 | 
					    return getattr(table, field_name) or []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _get_recorded_associated_file_object_list(self, metadata):
 | 
				
			||||||
 | 
					    """Gets a list of AssociatedFileT objects recorded in the metadata.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Associated files may be attached to a model, a subgraph, or an input/output
 | 
				
			||||||
 | 
					    tensor.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      metadata: the ModelMetadataT object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      List of recorded AssociatedFileT objects.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    recorded_files = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Add associated files attached to ModelMetadata.
 | 
				
			||||||
 | 
					    recorded_files += self._get_associated_files_from_table(
 | 
				
			||||||
 | 
					        metadata, "associatedFiles")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Add associated files attached to each SubgraphMetadata.
 | 
				
			||||||
 | 
					    for subgraph in metadata.subgraphMetadata or []:
 | 
				
			||||||
 | 
					      recorded_files += self._get_associated_files_from_table(
 | 
				
			||||||
 | 
					          subgraph, "associatedFiles")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Add associated files attached to each input tensor.
 | 
				
			||||||
 | 
					      for tensor_metadata in subgraph.inputTensorMetadata or []:
 | 
				
			||||||
 | 
					        recorded_files += self._get_associated_files_from_table(
 | 
				
			||||||
 | 
					            tensor_metadata, "associatedFiles")
 | 
				
			||||||
 | 
					        recorded_files += self._get_associated_files_from_process_units(
 | 
				
			||||||
 | 
					            tensor_metadata, "processUnits")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Add associated files attached to each output tensor.
 | 
				
			||||||
 | 
					      for tensor_metadata in subgraph.outputTensorMetadata or []:
 | 
				
			||||||
 | 
					        recorded_files += self._get_associated_files_from_table(
 | 
				
			||||||
 | 
					            tensor_metadata, "associatedFiles")
 | 
				
			||||||
 | 
					        recorded_files += self._get_associated_files_from_process_units(
 | 
				
			||||||
 | 
					            tensor_metadata, "processUnits")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Add associated files attached to the input_process_units.
 | 
				
			||||||
 | 
					      recorded_files += self._get_associated_files_from_process_units(
 | 
				
			||||||
 | 
					          subgraph, "inputProcessUnits")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Add associated files attached to the output_process_units.
 | 
				
			||||||
 | 
					      recorded_files += self._get_associated_files_from_process_units(
 | 
				
			||||||
 | 
					          subgraph, "outputProcessUnits")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return recorded_files
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _populate_associated_files(self):
 | 
				
			||||||
 | 
					    """Concatenates associated files after TensorFlow Lite model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    If the MetadataPopulator object is created using the method,
 | 
				
			||||||
 | 
					    with_model_file(model_file), the model file will be updated.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Opens up the model file in "appending" mode.
 | 
				
			||||||
 | 
					    # If self._model_file already has pack files, zipfile will concatenate
 | 
				
			||||||
 | 
					    # addition files after self._model_file. For example, suppose we have
 | 
				
			||||||
 | 
					    # self._model_file = old_tflite_file | label1.txt | label2.txt
 | 
				
			||||||
 | 
					    # Then after trigger populate() to add label3.txt, self._model_file becomes
 | 
				
			||||||
 | 
					    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
 | 
				
			||||||
 | 
					    with tempfile.SpooledTemporaryFile() as temp:
 | 
				
			||||||
 | 
					      # (1) Copy content from model file of to temp file.
 | 
				
			||||||
 | 
					      with _open_file(self._model_file, "rb") as f:
 | 
				
			||||||
 | 
					        shutil.copyfileobj(f, temp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # (2) Append of to a temp file as a zip.
 | 
				
			||||||
 | 
					      with _open_as_zipfile(temp, "a") as zf:
 | 
				
			||||||
 | 
					        for file_name, file_buffer in self._associated_files.items():
 | 
				
			||||||
 | 
					          zf.writestr(file_name, file_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # (3) Copy temp file to model file.
 | 
				
			||||||
 | 
					      temp.seek(0)
 | 
				
			||||||
 | 
					      with _open_file(self._model_file, "wb") as f:
 | 
				
			||||||
 | 
					        shutil.copyfileobj(temp, f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _populate_metadata_buffer(self):
 | 
				
			||||||
 | 
					    """Populates the metadata buffer (in bytearray) into the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Inserts metadata_buf into the metadata field of schema.Model. If the
 | 
				
			||||||
 | 
					    MetadataPopulator object is created using the method,
 | 
				
			||||||
 | 
					    with_model_file(model_file), the model file will be updated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Existing metadata buffer (if applied) will be overridden by the new metadata
 | 
				
			||||||
 | 
					    buffer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with _open_file(self._model_file, "rb") as f:
 | 
				
			||||||
 | 
					      model_buf = f.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = _schema_fb.ModelT.InitFromObj(
 | 
				
			||||||
 | 
					        _schema_fb.Model.GetRootAsModel(model_buf, 0))
 | 
				
			||||||
 | 
					    buffer_field = _schema_fb.BufferT()
 | 
				
			||||||
 | 
					    buffer_field.data = self._metadata_buf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    is_populated = False
 | 
				
			||||||
 | 
					    if not model.metadata:
 | 
				
			||||||
 | 
					      model.metadata = []
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      # Check if metadata has already been populated.
 | 
				
			||||||
 | 
					      for meta in model.metadata:
 | 
				
			||||||
 | 
					        if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
 | 
				
			||||||
 | 
					          is_populated = True
 | 
				
			||||||
 | 
					          model.buffers[meta.buffer] = buffer_field
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not is_populated:
 | 
				
			||||||
 | 
					      if not model.buffers:
 | 
				
			||||||
 | 
					        model.buffers = []
 | 
				
			||||||
 | 
					      model.buffers.append(buffer_field)
 | 
				
			||||||
 | 
					      # Creates a new metadata field.
 | 
				
			||||||
 | 
					      metadata_field = _schema_fb.MetadataT()
 | 
				
			||||||
 | 
					      metadata_field.name = self.METADATA_FIELD_NAME
 | 
				
			||||||
 | 
					      metadata_field.buffer = len(model.buffers) - 1
 | 
				
			||||||
 | 
					      model.metadata.append(metadata_field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Packs model back to a flatbuffer binaray file.
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    model_buf = b.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Saves the updated model buffer to model file.
 | 
				
			||||||
 | 
					    # Gets files that have been packed to self._model_file.
 | 
				
			||||||
 | 
					    packed_files = self.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					    if packed_files:
 | 
				
			||||||
 | 
					      # Writes the updated model buffer and associated files into a new model
 | 
				
			||||||
 | 
					      # file (in memory). Then overwrites the original model file.
 | 
				
			||||||
 | 
					      with tempfile.SpooledTemporaryFile() as temp:
 | 
				
			||||||
 | 
					        temp.write(model_buf)
 | 
				
			||||||
 | 
					        self._copy_archived_files(self._model_file, packed_files, temp)
 | 
				
			||||||
 | 
					        temp.seek(0)
 | 
				
			||||||
 | 
					        with _open_file(self._model_file, "wb") as f:
 | 
				
			||||||
 | 
					          shutil.copyfileobj(temp, f)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      with _open_file(self._model_file, "wb") as f:
 | 
				
			||||||
 | 
					        f.write(model_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _use_basename_for_associated_files_in_metadata(self, metadata):
 | 
				
			||||||
 | 
					    """Removes any associated file local directory (if exists)."""
 | 
				
			||||||
 | 
					    for file in self._get_recorded_associated_file_object_list(metadata):
 | 
				
			||||||
 | 
					      file.name = os.path.basename(file.name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _validate_metadata(self, metadata_buf):
 | 
				
			||||||
 | 
					    """Validates the metadata to be populated."""
 | 
				
			||||||
 | 
					    _assert_metadata_buffer_identifier(metadata_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Verify the number of SubgraphMetadata is exactly one.
 | 
				
			||||||
 | 
					    # TFLite currently only support one subgraph.
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
 | 
				
			||||||
 | 
					        metadata_buf, 0)
 | 
				
			||||||
 | 
					    if model_meta.SubgraphMetadataLength() != 1:
 | 
				
			||||||
 | 
					      raise ValueError("The number of SubgraphMetadata should be exactly one, "
 | 
				
			||||||
 | 
					                       "but got {0}.".format(
 | 
				
			||||||
 | 
					                           model_meta.SubgraphMetadataLength()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Verify if the number of tensor metadata matches the number of tensors.
 | 
				
			||||||
 | 
					    with _open_file(self._model_file, "rb") as f:
 | 
				
			||||||
 | 
					      model_buf = f.read()
 | 
				
			||||||
 | 
					    model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    num_input_tensors = model.Subgraphs(0).InputsLength()
 | 
				
			||||||
 | 
					    num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
 | 
				
			||||||
 | 
					    if num_input_tensors != num_input_meta:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "The number of input tensors ({0}) should match the number of "
 | 
				
			||||||
 | 
					          "input tensor metadata ({1})".format(num_input_tensors,
 | 
				
			||||||
 | 
					                                               num_input_meta))
 | 
				
			||||||
 | 
					    num_output_tensors = model.Subgraphs(0).OutputsLength()
 | 
				
			||||||
 | 
					    num_output_meta = model_meta.SubgraphMetadata(
 | 
				
			||||||
 | 
					        0).OutputTensorMetadataLength()
 | 
				
			||||||
 | 
					    if num_output_tensors != num_output_meta:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "The number of output tensors ({0}) should match the number of "
 | 
				
			||||||
 | 
					          "output tensor metadata ({1})".format(num_output_tensors,
 | 
				
			||||||
 | 
					                                                num_output_meta))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _MetadataPopulatorWithBuffer(MetadataPopulator):
 | 
				
			||||||
 | 
					  """Subclass of MetadtaPopulator that populates metadata to a model buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  This class is used to populate metadata into a in-memory model buffer. As we
 | 
				
			||||||
 | 
					  use Zip API to concatenate associated files after tflite model file, the
 | 
				
			||||||
 | 
					  populating operation is developed based on a model file. For in-memory model
 | 
				
			||||||
 | 
					  buffer, we create a tempfile to serve the populating operation. This class is
 | 
				
			||||||
 | 
					  then used to generate this tempfile, and delete the file when the
 | 
				
			||||||
 | 
					  MetadataPopulator object is deleted.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self, model_buf):
 | 
				
			||||||
 | 
					    """Constructor for _MetadataPopulatorWithBuffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_buf: TensorFlow Lite model buffer in bytearray.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: model_buf is empty.
 | 
				
			||||||
 | 
					      ValueError: model_buf does not have the expected flatbuffer identifer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not model_buf:
 | 
				
			||||||
 | 
					      raise ValueError("model_buf cannot be empty.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with tempfile.NamedTemporaryFile() as temp:
 | 
				
			||||||
 | 
					      model_file = temp.name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with _open_file(model_file, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(model_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    super().__init__(model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __del__(self):
 | 
				
			||||||
 | 
					    """Destructor of _MetadataPopulatorWithBuffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Deletes the tempfile.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if os.path.exists(self._model_file):
 | 
				
			||||||
 | 
					      os.remove(self._model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataDisplayer(object):
 | 
				
			||||||
 | 
					  """Displays metadata and associated file info in human-readable format."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self, model_buffer, metadata_buffer, associated_file_list):
 | 
				
			||||||
 | 
					    """Constructor for MetadataDisplayer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_buffer: valid buffer of the model file.
 | 
				
			||||||
 | 
					      metadata_buffer: valid buffer of the metadata file.
 | 
				
			||||||
 | 
					      associated_file_list: list of associate files in the model file.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    _assert_model_buffer_identifier(model_buffer)
 | 
				
			||||||
 | 
					    _assert_metadata_buffer_identifier(metadata_buffer)
 | 
				
			||||||
 | 
					    self._model_buffer = model_buffer
 | 
				
			||||||
 | 
					    self._metadata_buffer = metadata_buffer
 | 
				
			||||||
 | 
					    self._associated_file_list = associated_file_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def with_model_file(cls, model_file):
 | 
				
			||||||
 | 
					    """Creates a MetadataDisplayer object for the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_file: valid path to a TensorFlow Lite model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      MetadataDisplayer object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      IOError: File not found.
 | 
				
			||||||
 | 
					      ValueError: The model does not have metadata.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    _assert_file_exist(model_file)
 | 
				
			||||||
 | 
					    with _open_file(model_file, "rb") as f:
 | 
				
			||||||
 | 
					      return cls.with_model_buffer(f.read())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def with_model_buffer(cls, model_buffer):
 | 
				
			||||||
 | 
					    """Creates a MetadataDisplayer object for a file buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_buffer: TensorFlow Lite model buffer in bytearray.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      MetadataDisplayer object.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not model_buffer:
 | 
				
			||||||
 | 
					      raise ValueError("model_buffer cannot be empty.")
 | 
				
			||||||
 | 
					    metadata_buffer = get_metadata_buffer(model_buffer)
 | 
				
			||||||
 | 
					    if not metadata_buffer:
 | 
				
			||||||
 | 
					      raise ValueError("The model does not have metadata.")
 | 
				
			||||||
 | 
					    associated_file_list = cls._parse_packed_associted_file_list(model_buffer)
 | 
				
			||||||
 | 
					    return cls(model_buffer, metadata_buffer, associated_file_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_associated_file_buffer(self, filename):
 | 
				
			||||||
 | 
					    """Get the specified associated file content in bytearray.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      filename: name of the file to be extracted.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      The file content in bytearray.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: if the file does not exist in the model.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if filename not in self._associated_file_list:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "The file, {}, does not exist in the model.".format(filename))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf:
 | 
				
			||||||
 | 
					      return zf.read(filename)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_metadata_buffer(self):
 | 
				
			||||||
 | 
					    """Get the metadata buffer in bytearray out from the model."""
 | 
				
			||||||
 | 
					    return copy.deepcopy(self._metadata_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_metadata_json(self):
 | 
				
			||||||
 | 
					    """Converts the metadata into a json string."""
 | 
				
			||||||
 | 
					    return convert_to_json(self._metadata_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_packed_associated_file_list(self):
 | 
				
			||||||
 | 
					    """Returns a list of associated files that are packed in the model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A name list of associated files.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return copy.deepcopy(self._associated_file_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @staticmethod
 | 
				
			||||||
 | 
					  def _parse_packed_associted_file_list(model_buf):
 | 
				
			||||||
 | 
					    """Gets a list of associated files packed to the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_buf: valid file buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      List of packed associated files.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					      with _open_as_zipfile(io.BytesIO(model_buf)) as zf:
 | 
				
			||||||
 | 
					        return zf.namelist()
 | 
				
			||||||
 | 
					    except zipfile.BadZipFile:
 | 
				
			||||||
 | 
					      return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Create an individual method for getting the metadata json file, so that it can
 | 
				
			||||||
 | 
					# be used as a standalone util.
 | 
				
			||||||
 | 
					def convert_to_json(metadata_buffer):
 | 
				
			||||||
 | 
					  """Converts the metadata into a json string.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    metadata_buffer: valid metadata buffer in bytes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Returns:
 | 
				
			||||||
 | 
					    Metadata in JSON format.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Raises:
 | 
				
			||||||
 | 
					    ValueError: error occured when parsing the metadata schema file.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  opt = _pywrap_flatbuffers.IDLOptions()
 | 
				
			||||||
 | 
					  opt.strict_json = True
 | 
				
			||||||
 | 
					  parser = _pywrap_flatbuffers.Parser(opt)
 | 
				
			||||||
 | 
					  with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f:
 | 
				
			||||||
 | 
					    metadata_schema_content = f.read()
 | 
				
			||||||
 | 
					  if not parser.parse(metadata_schema_content):
 | 
				
			||||||
 | 
					    raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
 | 
				
			||||||
 | 
					  return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _assert_file_exist(filename):
 | 
				
			||||||
 | 
					  """Checks if a file exists."""
 | 
				
			||||||
 | 
					  if not _exists_file(filename):
 | 
				
			||||||
 | 
					    raise IOError("File, '{0}', does not exist.".format(filename))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _assert_model_file_identifier(model_file):
 | 
				
			||||||
 | 
					  """Checks if a model file has the expected TFLite schema identifier."""
 | 
				
			||||||
 | 
					  _assert_file_exist(model_file)
 | 
				
			||||||
 | 
					  with _open_file(model_file, "rb") as f:
 | 
				
			||||||
 | 
					    _assert_model_buffer_identifier(f.read())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _assert_model_buffer_identifier(model_buf):
 | 
				
			||||||
 | 
					  if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
 | 
				
			||||||
 | 
					    raise ValueError(
 | 
				
			||||||
 | 
					        "The model provided does not have the expected identifier, and "
 | 
				
			||||||
 | 
					        "may not be a valid TFLite model.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _assert_metadata_buffer_identifier(metadata_buf):
 | 
				
			||||||
 | 
					  """Checks if a metadata buffer has the expected Metadata schema identifier."""
 | 
				
			||||||
 | 
					  if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
 | 
				
			||||||
 | 
					      metadata_buf, 0):
 | 
				
			||||||
 | 
					    raise ValueError(
 | 
				
			||||||
 | 
					        "The metadata buffer does not have the expected identifier, and may not"
 | 
				
			||||||
 | 
					        " be a valid TFLite Metadata.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_metadata_buffer(model_buf):
 | 
				
			||||||
 | 
					  """Returns the metadata in the model file as a buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    model_buf: valid buffer of the model file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Returns:
 | 
				
			||||||
 | 
					    Metadata buffer. Returns `None` if the model does not have metadata.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Gets metadata from the model file.
 | 
				
			||||||
 | 
					  for i in range(tflite_model.MetadataLength()):
 | 
				
			||||||
 | 
					    meta = tflite_model.Metadata(i)
 | 
				
			||||||
 | 
					    if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
 | 
				
			||||||
 | 
					      buffer_index = meta.Buffer()
 | 
				
			||||||
 | 
					      metadata = tflite_model.Buffers(buffer_index)
 | 
				
			||||||
 | 
					      return metadata.DataAsNumpy().tobytes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return None
 | 
				
			||||||
							
								
								
									
										34
									
								
								mediapipe/tasks/python/metadata/metadata_displayer_cli.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								mediapipe/tasks/python/metadata/metadata_displayer_cli.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,34 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					"""CLI tool for display metadata."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from absl import app
 | 
				
			||||||
 | 
					from absl import flags
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.metadata import metadata
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					FLAGS = flags.FLAGS
 | 
				
			||||||
 | 
					flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.')
 | 
				
			||||||
 | 
					flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main(_):
 | 
				
			||||||
 | 
					  displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path)
 | 
				
			||||||
 | 
					  with open(FLAGS.export_json_path, 'w') as f:
 | 
				
			||||||
 | 
					    f.write(displayer.get_metadata_json())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					  app.run(main)
 | 
				
			||||||
							
								
								
									
										26
									
								
								mediapipe/tasks/python/metadata/metadata_parser.py.template
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								mediapipe/tasks/python/metadata/metadata_parser.py.template
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,26 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					"""Information about the metadata parser that this python library depends on."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from __future__ import division
 | 
				
			||||||
 | 
					from __future__ import print_function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataParser(object):
 | 
				
			||||||
 | 
					  """Information about the metadata parser."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # The version of the metadata parser.
 | 
				
			||||||
 | 
					  VERSION = "{LATEST_METADATA_PARSER_VERSION}"
 | 
				
			||||||
| 
						 | 
					@ -19,9 +19,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
licenses(["notice"])
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
    name = "test_util",
 | 
					    name = "test_utils",
 | 
				
			||||||
    testonly = 1,
 | 
					    testonly = 1,
 | 
				
			||||||
    srcs = ["test_util.py"],
 | 
					    srcs = ["test_utils.py"],
 | 
				
			||||||
    srcs_version = "PY3",
 | 
					    srcs_version = "PY3",
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        "//mediapipe/python:_framework_bindings",
 | 
					        "//mediapipe/python:_framework_bindings",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										31
									
								
								mediapipe/tasks/python/test/metadata/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								mediapipe/tasks/python/test/metadata/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,31 @@
 | 
				
			||||||
 | 
					package(
 | 
				
			||||||
 | 
					    default_visibility = [
 | 
				
			||||||
 | 
					        "//visibility:public",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    licenses = ["notice"],  # Apache 2.0
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "metadata_test",
 | 
				
			||||||
 | 
					    srcs = ["metadata_test.py"],
 | 
				
			||||||
 | 
					    data = ["//mediapipe/tasks/testdata/metadata:data_files"],
 | 
				
			||||||
 | 
					    python_version = "PY3",
 | 
				
			||||||
 | 
					    srcs_version = "PY3",
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/metadata:metadata_schema_py",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/metadata:schema_py",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/metadata",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
 | 
					        "@flatbuffers//:runtime_py",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "metadata_parser_test",
 | 
				
			||||||
 | 
					    srcs = ["metadata_parser_test.py"],
 | 
				
			||||||
 | 
					    python_version = "PY3",
 | 
				
			||||||
 | 
					    srcs_version = "PY2AND3",
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/metadata",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
										37
									
								
								mediapipe/tasks/python/test/metadata/metadata_parser_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								mediapipe/tasks/python/test/metadata/metadata_parser_test.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,37 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					"""Tests for mediapipe.tasks.python.metadata.metadata_parser."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from __future__ import division
 | 
				
			||||||
 | 
					from __future__ import print_function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					from absl.testing import absltest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.metadata import metadata_parser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataParserTest(absltest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testVersionWellFormedSemanticVersion(self):
 | 
				
			||||||
 | 
					    # Validates that the version is well-formed (x.y.z).
 | 
				
			||||||
 | 
					    self.assertTrue(
 | 
				
			||||||
 | 
					        re.match('[0-9]+\\.[0-9]+\\.[0-9]+',
 | 
				
			||||||
 | 
					                 metadata_parser.MetadataParser.VERSION))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					  absltest.main()
 | 
				
			||||||
							
								
								
									
										857
									
								
								mediapipe/tasks/python/test/metadata/metadata_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										857
									
								
								mediapipe/tasks/python/test/metadata/metadata_test.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,857 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					"""Tests for mediapipe.tasks.python.metadata.metadata."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import enum
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from absl.testing import absltest
 | 
				
			||||||
 | 
					from absl.testing import parameterized
 | 
				
			||||||
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import flatbuffers
 | 
				
			||||||
 | 
					from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
 | 
				
			||||||
 | 
					from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.metadata import metadata as _metadata
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.test import test_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Tokenizer(enum.Enum):
 | 
				
			||||||
 | 
					  BERT_TOKENIZER = 0
 | 
				
			||||||
 | 
					  SENTENCE_PIECE = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TensorType(enum.Enum):
 | 
				
			||||||
 | 
					  INPUT = 0
 | 
				
			||||||
 | 
					  OUTPUT = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _read_file(file_name, mode="rb"):
 | 
				
			||||||
 | 
					  with open(file_name, mode) as f:
 | 
				
			||||||
 | 
					    return f.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataTest(parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super(MetadataTest, self).setUp()
 | 
				
			||||||
 | 
					    self._invalid_model_buf = None
 | 
				
			||||||
 | 
					    self._invalid_file = "not_existed_file"
 | 
				
			||||||
 | 
					    self._model_buf = self._create_model_buf()
 | 
				
			||||||
 | 
					    self._model_file = self.create_tempfile().full_path
 | 
				
			||||||
 | 
					    with open(self._model_file, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(self._model_buf)
 | 
				
			||||||
 | 
					    self._metadata_file = self._create_metadata_file()
 | 
				
			||||||
 | 
					    self._metadata_file_with_version = self._create_metadata_file_with_version(
 | 
				
			||||||
 | 
					        self._metadata_file, "1.0.0")
 | 
				
			||||||
 | 
					    self._file1 = self.create_tempfile("file1").full_path
 | 
				
			||||||
 | 
					    self._file2 = self.create_tempfile("file2").full_path
 | 
				
			||||||
 | 
					    self._file2_content = b"file2_content"
 | 
				
			||||||
 | 
					    with open(self._file2, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(self._file2_content)
 | 
				
			||||||
 | 
					    self._file3 = self.create_tempfile("file3").full_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_model_buf(self):
 | 
				
			||||||
 | 
					    # Create a model with two inputs and one output, which matches the metadata
 | 
				
			||||||
 | 
					    # created by _create_metadata_file().
 | 
				
			||||||
 | 
					    metadata_field = _schema_fb.MetadataT()
 | 
				
			||||||
 | 
					    subgraph = _schema_fb.SubGraphT()
 | 
				
			||||||
 | 
					    subgraph.inputs = [0, 1]
 | 
				
			||||||
 | 
					    subgraph.outputs = [2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    metadata_field.name = "meta"
 | 
				
			||||||
 | 
					    buffer_field = _schema_fb.BufferT()
 | 
				
			||||||
 | 
					    model = _schema_fb.ModelT()
 | 
				
			||||||
 | 
					    model.subgraphs = [subgraph]
 | 
				
			||||||
 | 
					    # Creates the metadata and buffer fields for testing purposes.
 | 
				
			||||||
 | 
					    model.metadata = [metadata_field, metadata_field]
 | 
				
			||||||
 | 
					    model.buffers = [buffer_field, buffer_field, buffer_field]
 | 
				
			||||||
 | 
					    model_builder = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    model_builder.Finish(
 | 
				
			||||||
 | 
					        model.Pack(model_builder),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    return model_builder.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_metadata_file(self):
 | 
				
			||||||
 | 
					    associated_file1 = _metadata_fb.AssociatedFileT()
 | 
				
			||||||
 | 
					    associated_file1.name = b"file1"
 | 
				
			||||||
 | 
					    associated_file2 = _metadata_fb.AssociatedFileT()
 | 
				
			||||||
 | 
					    associated_file2.name = b"file2"
 | 
				
			||||||
 | 
					    self.expected_recorded_files = [
 | 
				
			||||||
 | 
					        six.ensure_str(associated_file1.name),
 | 
				
			||||||
 | 
					        six.ensure_str(associated_file2.name)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    input_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    output_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    output_meta.associatedFiles = [associated_file2]
 | 
				
			||||||
 | 
					    subgraph = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    # Create a model with two inputs and one output.
 | 
				
			||||||
 | 
					    subgraph.inputTensorMetadata = [input_meta, input_meta]
 | 
				
			||||||
 | 
					    subgraph.outputTensorMetadata = [output_meta]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    model_meta.name = "Mobilenet_quantized"
 | 
				
			||||||
 | 
					    model_meta.associatedFiles = [associated_file1]
 | 
				
			||||||
 | 
					    model_meta.subgraphMetadata = [subgraph]
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(
 | 
				
			||||||
 | 
					        model_meta.Pack(b),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    metadata_file = self.create_tempfile().full_path
 | 
				
			||||||
 | 
					    with open(metadata_file, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(b.Output())
 | 
				
			||||||
 | 
					    return metadata_file
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_model_buffer_with_wrong_identifier(self):
 | 
				
			||||||
 | 
					    wrong_identifier = b"widn"
 | 
				
			||||||
 | 
					    model = _schema_fb.ModelT()
 | 
				
			||||||
 | 
					    model_builder = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    model_builder.Finish(model.Pack(model_builder), wrong_identifier)
 | 
				
			||||||
 | 
					    return model_builder.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_metadata_buffer_with_wrong_identifier(self):
 | 
				
			||||||
 | 
					    # Creates a metadata with wrong identifier
 | 
				
			||||||
 | 
					    wrong_identifier = b"widn"
 | 
				
			||||||
 | 
					    metadata = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    metadata_builder = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
 | 
				
			||||||
 | 
					    return metadata_builder.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
 | 
				
			||||||
 | 
					                                         identifier):
 | 
				
			||||||
 | 
					    # For testing purposes only. MetadataPopulator cannot populate metadata with
 | 
				
			||||||
 | 
					    # wrong identifiers.
 | 
				
			||||||
 | 
					    model = _schema_fb.ModelT.InitFromObj(
 | 
				
			||||||
 | 
					        _schema_fb.Model.GetRootAsModel(model_buf, 0))
 | 
				
			||||||
 | 
					    buffer_field = _schema_fb.BufferT()
 | 
				
			||||||
 | 
					    buffer_field.data = metadata_buf
 | 
				
			||||||
 | 
					    model.buffers = [buffer_field]
 | 
				
			||||||
 | 
					    # Creates a new metadata field.
 | 
				
			||||||
 | 
					    metadata_field = _schema_fb.MetadataT()
 | 
				
			||||||
 | 
					    metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
 | 
				
			||||||
 | 
					    metadata_field.buffer = len(model.buffers) - 1
 | 
				
			||||||
 | 
					    model.metadata = [metadata_field]
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(model.Pack(b), identifier)
 | 
				
			||||||
 | 
					    return b.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_metadata_file_with_version(self, metadata_file, min_version):
 | 
				
			||||||
 | 
					    # Creates a new metadata file with the specified min_version for testing
 | 
				
			||||||
 | 
					    # purposes.
 | 
				
			||||||
 | 
					    metadata_buf = bytearray(_read_file(metadata_file))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
 | 
				
			||||||
 | 
					        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
 | 
				
			||||||
 | 
					    metadata.minParserVersion = min_version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(
 | 
				
			||||||
 | 
					        metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    metadata_file_with_version = self.create_tempfile().full_path
 | 
				
			||||||
 | 
					    with open(metadata_file_with_version, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(b.Output())
 | 
				
			||||||
 | 
					    return metadata_file_with_version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataPopulatorTest(MetadataTest):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_bert_tokenizer(self):
 | 
				
			||||||
 | 
					    vocab_file_name = "bert_vocab"
 | 
				
			||||||
 | 
					    vocab = _metadata_fb.AssociatedFileT()
 | 
				
			||||||
 | 
					    vocab.name = vocab_file_name
 | 
				
			||||||
 | 
					    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
 | 
				
			||||||
 | 
					    tokenizer = _metadata_fb.ProcessUnitT()
 | 
				
			||||||
 | 
					    tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
 | 
				
			||||||
 | 
					    tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
 | 
				
			||||||
 | 
					    tokenizer.options.vocabFile = [vocab]
 | 
				
			||||||
 | 
					    return tokenizer, [vocab_file_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_sentence_piece_tokenizer(self):
 | 
				
			||||||
 | 
					    sp_model_name = "sp_model"
 | 
				
			||||||
 | 
					    vocab_file_name = "sp_vocab"
 | 
				
			||||||
 | 
					    sp_model = _metadata_fb.AssociatedFileT()
 | 
				
			||||||
 | 
					    sp_model.name = sp_model_name
 | 
				
			||||||
 | 
					    vocab = _metadata_fb.AssociatedFileT()
 | 
				
			||||||
 | 
					    vocab.name = vocab_file_name
 | 
				
			||||||
 | 
					    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
 | 
				
			||||||
 | 
					    tokenizer = _metadata_fb.ProcessUnitT()
 | 
				
			||||||
 | 
					    tokenizer.optionsType = (
 | 
				
			||||||
 | 
					        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
 | 
				
			||||||
 | 
					    tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
 | 
				
			||||||
 | 
					    tokenizer.options.sentencePieceModel = [sp_model]
 | 
				
			||||||
 | 
					    tokenizer.options.vocabFile = [vocab]
 | 
				
			||||||
 | 
					    return tokenizer, [sp_model_name, vocab_file_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_tokenizer(self, tokenizer_type):
 | 
				
			||||||
 | 
					    if tokenizer_type is Tokenizer.BERT_TOKENIZER:
 | 
				
			||||||
 | 
					      return self._create_bert_tokenizer()
 | 
				
			||||||
 | 
					    elif tokenizer_type is Tokenizer.SENTENCE_PIECE:
 | 
				
			||||||
 | 
					      return self._create_sentence_piece_tokenizer()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "The tokenizer type, {0}, is unsupported.".format(tokenizer_type))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_tempfiles(self, file_names):
 | 
				
			||||||
 | 
					    tempfiles = []
 | 
				
			||||||
 | 
					    for name in file_names:
 | 
				
			||||||
 | 
					      tempfiles.append(self.create_tempfile(name).full_path)
 | 
				
			||||||
 | 
					    return tempfiles
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_model_meta_with_subgraph_meta(self, subgraph_meta):
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    model_meta.subgraphMetadata = [subgraph_meta]
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(
 | 
				
			||||||
 | 
					        model_meta.Pack(b),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    return b.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testToValidModelFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    self.assertIsInstance(populator, _metadata.MetadataPopulator)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testToInvalidModelFile(self):
 | 
				
			||||||
 | 
					    with self.assertRaises(IOError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataPopulator.with_model_file(self._invalid_file)
 | 
				
			||||||
 | 
					    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
 | 
				
			||||||
 | 
					                     str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testToValidModelBuffer(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    self.assertIsInstance(populator, _metadata.MetadataPopulator)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testToInvalidModelBuffer(self):
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
 | 
				
			||||||
 | 
					    self.assertEqual("model_buf cannot be empty.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testToModelBufferWithWrongIdentifier(self):
 | 
				
			||||||
 | 
					    model_buf = self._create_model_buffer_with_wrong_identifier()
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataPopulator.with_model_buffer(model_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "The model provided does not have the expected identifier, and "
 | 
				
			||||||
 | 
					        "may not be a valid TFLite model.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testSinglePopulateAssociatedFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    packed_files = populator.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					    expected_packed_files = [os.path.basename(self._file1)]
 | 
				
			||||||
 | 
					    self.assertEqual(set(packed_files), set(expected_packed_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testRepeatedPopulateAssociatedFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1, self._file2])
 | 
				
			||||||
 | 
					    # Loads file2 multiple times.
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file2])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    packed_files = populator.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					    expected_packed_files = [
 | 
				
			||||||
 | 
					        os.path.basename(self._file1),
 | 
				
			||||||
 | 
					        os.path.basename(self._file2)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    self.assertLen(packed_files, 2)
 | 
				
			||||||
 | 
					    self.assertEqual(set(packed_files), set(expected_packed_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Check if the model buffer read from file is the same as that read from
 | 
				
			||||||
 | 
					    # get_model_buffer().
 | 
				
			||||||
 | 
					    model_buf_from_file = _read_file(self._model_file)
 | 
				
			||||||
 | 
					    model_buf_from_getter = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    self.assertEqual(model_buf_from_file, model_buf_from_getter)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateInvalidAssociatedFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    with self.assertRaises(IOError) as error:
 | 
				
			||||||
 | 
					      populator.load_associated_files([self._invalid_file])
 | 
				
			||||||
 | 
					    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
 | 
				
			||||||
 | 
					                     str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulatePackedAssociatedFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_associated_files([self._file1])
 | 
				
			||||||
 | 
					      populator.populate()
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "File, '{0}', has already been packed.".format(
 | 
				
			||||||
 | 
					            os.path.basename(self._file1)), str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadAssociatedFileBuffers(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    file_buffer = _read_file(self._file1)
 | 
				
			||||||
 | 
					    populator.load_associated_file_buffers({self._file1: file_buffer})
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    packed_files = populator.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					    expected_packed_files = [os.path.basename(self._file1)]
 | 
				
			||||||
 | 
					    self.assertEqual(set(packed_files), set(expected_packed_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testRepeatedLoadAssociatedFileBuffers(self):
 | 
				
			||||||
 | 
					    file_buffer1 = _read_file(self._file1)
 | 
				
			||||||
 | 
					    file_buffer2 = _read_file(self._file2)
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    populator.load_associated_file_buffers({
 | 
				
			||||||
 | 
					        self._file1: file_buffer1,
 | 
				
			||||||
 | 
					        self._file2: file_buffer2
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					    # Loads file2 multiple times.
 | 
				
			||||||
 | 
					    populator.load_associated_file_buffers({self._file2: file_buffer2})
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    packed_files = populator.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					    expected_packed_files = [
 | 
				
			||||||
 | 
					        os.path.basename(self._file1),
 | 
				
			||||||
 | 
					        os.path.basename(self._file2)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    self.assertEqual(set(packed_files), set(expected_packed_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Check if the model buffer read from file is the same as that read from
 | 
				
			||||||
 | 
					    # get_model_buffer().
 | 
				
			||||||
 | 
					    model_buf_from_file = _read_file(self._model_file)
 | 
				
			||||||
 | 
					    model_buf_from_getter = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    self.assertEqual(model_buf_from_file, model_buf_from_getter)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadPackedAssociatedFileBuffersFails(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    file_buffer = _read_file(self._file1)
 | 
				
			||||||
 | 
					    populator.load_associated_file_buffers({self._file1: file_buffer})
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load file1 again should fail.
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_associated_file_buffers({self._file1: file_buffer})
 | 
				
			||||||
 | 
					      populator.populate()
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "File, '{0}', has already been packed.".format(
 | 
				
			||||||
 | 
					            os.path.basename(self._file1)), str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetPackedAssociatedFileList(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    packed_files = populator.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(packed_files, [])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateMetadataFileToEmptyModelFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_file(self._metadata_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1, self._file2])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_buf_from_file = _read_file(self._model_file)
 | 
				
			||||||
 | 
					    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
 | 
				
			||||||
 | 
					    # self._model_file already has two elements in the metadata field, so the
 | 
				
			||||||
 | 
					    # populated TFLite metadata will be the third element.
 | 
				
			||||||
 | 
					    metadata_field = model.Metadata(2)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        six.ensure_str(metadata_field.Name()),
 | 
				
			||||||
 | 
					        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    buffer_index = metadata_field.Buffer()
 | 
				
			||||||
 | 
					    buffer_data = model.Buffers(buffer_index)
 | 
				
			||||||
 | 
					    metadata_buf_np = buffer_data.DataAsNumpy()
 | 
				
			||||||
 | 
					    metadata_buf = metadata_buf_np.tobytes()
 | 
				
			||||||
 | 
					    expected_metadata_buf = bytearray(
 | 
				
			||||||
 | 
					        _read_file(self._metadata_file_with_version))
 | 
				
			||||||
 | 
					    self.assertEqual(metadata_buf, expected_metadata_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    recorded_files = populator.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Up to now, we've proved the correctness of the model buffer that read from
 | 
				
			||||||
 | 
					    # file. Then we'll test if get_model_buffer() gives the same model buffer.
 | 
				
			||||||
 | 
					    model_buf_from_getter = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    self.assertEqual(model_buf_from_file, model_buf_from_getter)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateMetadataFileWithoutAssociatedFiles(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_file(self._metadata_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1])
 | 
				
			||||||
 | 
					    # Suppose to populate self._file2, because it is recorded in the metadta.
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.populate()
 | 
				
			||||||
 | 
					    self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
 | 
				
			||||||
 | 
					                      "not been loaded into the populator.").format(
 | 
				
			||||||
 | 
					                          os.path.basename(self._file2)), str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateMetadataBufferWithWrongIdentifier(self):
 | 
				
			||||||
 | 
					    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_metadata_buffer(metadata_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "The metadata buffer does not have the expected identifier, and may not"
 | 
				
			||||||
 | 
					        " be a valid TFLite Metadata.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _assert_golden_metadata(self, model_file):
 | 
				
			||||||
 | 
					    model_buf_from_file = _read_file(model_file)
 | 
				
			||||||
 | 
					    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
 | 
				
			||||||
 | 
					    # There are two elements in model.Metadata array before the population.
 | 
				
			||||||
 | 
					    # Metadata should be packed to the third element in the array.
 | 
				
			||||||
 | 
					    metadata_field = model.Metadata(2)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        six.ensure_str(metadata_field.Name()),
 | 
				
			||||||
 | 
					        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    buffer_index = metadata_field.Buffer()
 | 
				
			||||||
 | 
					    buffer_data = model.Buffers(buffer_index)
 | 
				
			||||||
 | 
					    metadata_buf_np = buffer_data.DataAsNumpy()
 | 
				
			||||||
 | 
					    metadata_buf = metadata_buf_np.tobytes()
 | 
				
			||||||
 | 
					    expected_metadata_buf = bytearray(
 | 
				
			||||||
 | 
					        _read_file(self._metadata_file_with_version))
 | 
				
			||||||
 | 
					    self.assertEqual(metadata_buf, expected_metadata_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
 | 
				
			||||||
 | 
					    # First, creates a dummy metadata different from self._metadata_file. It
 | 
				
			||||||
 | 
					    # needs to have the same input/output tensor numbers as self._model_file.
 | 
				
			||||||
 | 
					    # Populates it and the associated files into the model.
 | 
				
			||||||
 | 
					    input_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    output_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    subgraph = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    # Create a model with two inputs and one output.
 | 
				
			||||||
 | 
					    subgraph.inputTensorMetadata = [input_meta, input_meta]
 | 
				
			||||||
 | 
					    subgraph.outputTensorMetadata = [output_meta]
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    model_meta.subgraphMetadata = [subgraph]
 | 
				
			||||||
 | 
					    b = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    b.Finish(
 | 
				
			||||||
 | 
					        model_meta.Pack(b),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    metadata_buf = b.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Populate the metadata.
 | 
				
			||||||
 | 
					    populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator1.load_metadata_buffer(metadata_buf)
 | 
				
			||||||
 | 
					    populator1.load_associated_files([self._file1, self._file2])
 | 
				
			||||||
 | 
					    populator1.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Then, populate the metadata again.
 | 
				
			||||||
 | 
					    populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator2.load_metadata_file(self._metadata_file)
 | 
				
			||||||
 | 
					    populator2.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Test if the metadata is populated correctly.
 | 
				
			||||||
 | 
					    self._assert_golden_metadata(self._model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_file(self._metadata_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1, self._file2])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Tests if the metadata is populated correctly.
 | 
				
			||||||
 | 
					    self._assert_golden_metadata(self._model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    recorded_files = populator.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Up to now, we've proved the correctness of the model buffer that read from
 | 
				
			||||||
 | 
					    # file. Then we'll test if get_model_buffer() gives the same model buffer.
 | 
				
			||||||
 | 
					    model_buf_from_file = _read_file(self._model_file)
 | 
				
			||||||
 | 
					    model_buf_from_getter = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    self.assertEqual(model_buf_from_file, model_buf_from_getter)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateInvalidMetadataFile(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    with self.assertRaises(IOError) as error:
 | 
				
			||||||
 | 
					      populator.load_metadata_file(self._invalid_file)
 | 
				
			||||||
 | 
					    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
 | 
				
			||||||
 | 
					                     str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulateInvalidMetadataBuffer(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_metadata_buffer([])
 | 
				
			||||||
 | 
					    self.assertEqual("The metadata to be populated is empty.",
 | 
				
			||||||
 | 
					                     str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetModelBufferBeforePopulatingData(self):
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    model_buf = populator.get_model_buffer()
 | 
				
			||||||
 | 
					    expected_model_buf = self._model_buf
 | 
				
			||||||
 | 
					    self.assertEqual(model_buf, expected_model_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self):
 | 
				
			||||||
 | 
					    # Create a dummy metadata without Subgraph.
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    builder = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    builder.Finish(
 | 
				
			||||||
 | 
					        model_meta.Pack(builder),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    meta_buf = builder.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_metadata_buffer(meta_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "The number of SubgraphMetadata should be exactly one, but got 0.",
 | 
				
			||||||
 | 
					        str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self):
 | 
				
			||||||
 | 
					    # Create a dummy metadata with no input tensor metadata, while the expected
 | 
				
			||||||
 | 
					    # number is 2.
 | 
				
			||||||
 | 
					    output_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    subgprah_meta = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    subgprah_meta.outputTensorMetadata = [output_meta]
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    model_meta.subgraphMetadata = [subgprah_meta]
 | 
				
			||||||
 | 
					    builder = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    builder.Finish(
 | 
				
			||||||
 | 
					        model_meta.Pack(builder),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    meta_buf = builder.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_metadata_buffer(meta_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        ("The number of input tensors (2) should match the number of "
 | 
				
			||||||
 | 
					         "input tensor metadata (0)"), str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self):
 | 
				
			||||||
 | 
					    # Create a dummy metadata with no output tensor metadata, while the expected
 | 
				
			||||||
 | 
					    # number is 1.
 | 
				
			||||||
 | 
					    input_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    subgprah_meta = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    subgprah_meta.inputTensorMetadata = [input_meta, input_meta]
 | 
				
			||||||
 | 
					    model_meta = _metadata_fb.ModelMetadataT()
 | 
				
			||||||
 | 
					    model_meta.subgraphMetadata = [subgprah_meta]
 | 
				
			||||||
 | 
					    builder = flatbuffers.Builder(0)
 | 
				
			||||||
 | 
					    builder.Finish(
 | 
				
			||||||
 | 
					        model_meta.Pack(builder),
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    meta_buf = builder.Output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      populator.load_metadata_buffer(meta_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        ("The number of output tensors (1) should match the number of "
 | 
				
			||||||
 | 
					         "output tensor metadata (0)"), str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadMetadataAndAssociatedFilesShouldSucceeds(self):
 | 
				
			||||||
 | 
					    # Create a src model with metadata and two associated files.
 | 
				
			||||||
 | 
					    src_model_buf = self._create_model_buf()
 | 
				
			||||||
 | 
					    populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
 | 
				
			||||||
 | 
					    populator_src.load_metadata_file(self._metadata_file)
 | 
				
			||||||
 | 
					    populator_src.load_associated_files([self._file1, self._file2])
 | 
				
			||||||
 | 
					    populator_src.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create a model to be populated with the metadata and files from
 | 
				
			||||||
 | 
					    # src_model_buf.
 | 
				
			||||||
 | 
					    dst_model_buf = self._create_model_buf()
 | 
				
			||||||
 | 
					    populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
 | 
				
			||||||
 | 
					    populator_dst.load_metadata_and_associated_files(
 | 
				
			||||||
 | 
					        populator_src.get_model_buffer())
 | 
				
			||||||
 | 
					    populator_dst.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Tests if the metadata and associated files are populated correctly.
 | 
				
			||||||
 | 
					    dst_model_file = self.create_tempfile().full_path
 | 
				
			||||||
 | 
					    with open(dst_model_file, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(populator_dst.get_model_buffer())
 | 
				
			||||||
 | 
					    self._assert_golden_metadata(dst_model_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    recorded_files = populator_dst.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					          "testcase_name": "InputTensorWithBert",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.INPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.BERT_TOKENIZER
 | 
				
			||||||
 | 
					      }, {
 | 
				
			||||||
 | 
					          "testcase_name": "OutputTensorWithBert",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.OUTPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.BERT_TOKENIZER
 | 
				
			||||||
 | 
					      }, {
 | 
				
			||||||
 | 
					          "testcase_name": "InputTensorWithSentencePiece",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.INPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.SENTENCE_PIECE
 | 
				
			||||||
 | 
					      }, {
 | 
				
			||||||
 | 
					          "testcase_name": "OutputTensorWithSentencePiece",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.OUTPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.SENTENCE_PIECE
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					  def testGetRecordedAssociatedFileListWithSubgraphTensor(
 | 
				
			||||||
 | 
					      self, tensor_type, tokenizer_type):
 | 
				
			||||||
 | 
					    # Creates a metadata with the tokenizer in the tensor process units.
 | 
				
			||||||
 | 
					    tokenizer, expected_files = self._create_tokenizer(tokenizer_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create the tensor with process units.
 | 
				
			||||||
 | 
					    tensor = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    tensor.processUnits = [tokenizer]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create the subgrah with the tensor.
 | 
				
			||||||
 | 
					    subgraph = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    dummy_tensor_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    subgraph.outputTensorMetadata = [dummy_tensor_meta]
 | 
				
			||||||
 | 
					    if tensor_type is TensorType.INPUT:
 | 
				
			||||||
 | 
					      subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta]
 | 
				
			||||||
 | 
					      subgraph.outputTensorMetadata = [dummy_tensor_meta]
 | 
				
			||||||
 | 
					    elif tensor_type is TensorType.OUTPUT:
 | 
				
			||||||
 | 
					      subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
 | 
				
			||||||
 | 
					      subgraph.outputTensorMetadata = [tensor]
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "The tensor type, {0}, is unsupported.".format(tensor_type))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create a model metadata with the subgraph metadata
 | 
				
			||||||
 | 
					    meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Creates the tempfiles.
 | 
				
			||||||
 | 
					    tempfiles = self._create_tempfiles(expected_files)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Creates the MetadataPopulator object.
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_buffer(meta_buffer)
 | 
				
			||||||
 | 
					    populator.load_associated_files(tempfiles)
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    recorded_files = populator.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(set(recorded_files), set(expected_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					          "testcase_name": "InputTensorWithBert",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.INPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.BERT_TOKENIZER
 | 
				
			||||||
 | 
					      }, {
 | 
				
			||||||
 | 
					          "testcase_name": "OutputTensorWithBert",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.OUTPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.BERT_TOKENIZER
 | 
				
			||||||
 | 
					      }, {
 | 
				
			||||||
 | 
					          "testcase_name": "InputTensorWithSentencePiece",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.INPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.SENTENCE_PIECE
 | 
				
			||||||
 | 
					      }, {
 | 
				
			||||||
 | 
					          "testcase_name": "OutputTensorWithSentencePiece",
 | 
				
			||||||
 | 
					          "tensor_type": TensorType.OUTPUT,
 | 
				
			||||||
 | 
					          "tokenizer_type": Tokenizer.SENTENCE_PIECE
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					  def testGetRecordedAssociatedFileListWithSubgraphProcessUnits(
 | 
				
			||||||
 | 
					      self, tensor_type, tokenizer_type):
 | 
				
			||||||
 | 
					    # Creates a metadata with the tokenizer in the subgraph process units.
 | 
				
			||||||
 | 
					    tokenizer, expected_files = self._create_tokenizer(tokenizer_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create the subgraph with process units.
 | 
				
			||||||
 | 
					    subgraph = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    if tensor_type is TensorType.INPUT:
 | 
				
			||||||
 | 
					      subgraph.inputProcessUnits = [tokenizer]
 | 
				
			||||||
 | 
					    elif tensor_type is TensorType.OUTPUT:
 | 
				
			||||||
 | 
					      subgraph.outputProcessUnits = [tokenizer]
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "The tensor type, {0}, is unsupported.".format(tensor_type))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Creates the input and output tensor meta to match self._model_file.
 | 
				
			||||||
 | 
					    dummy_tensor_meta = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
 | 
				
			||||||
 | 
					    subgraph.outputTensorMetadata = [dummy_tensor_meta]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create a model metadata with the subgraph metadata
 | 
				
			||||||
 | 
					    meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Creates the tempfiles.
 | 
				
			||||||
 | 
					    tempfiles = self._create_tempfiles(expected_files)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Creates the MetadataPopulator object.
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_buffer(meta_buffer)
 | 
				
			||||||
 | 
					    populator.load_associated_files(tempfiles)
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    recorded_files = populator.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(set(recorded_files), set(expected_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testPopulatedFullPathAssociatedFileShouldSucceed(self):
 | 
				
			||||||
 | 
					    # Create AssociatedFileT using the full path file name.
 | 
				
			||||||
 | 
					    associated_file = _metadata_fb.AssociatedFileT()
 | 
				
			||||||
 | 
					    associated_file.name = self._file1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create model metadata with the associated file.
 | 
				
			||||||
 | 
					    subgraph = _metadata_fb.SubGraphMetadataT()
 | 
				
			||||||
 | 
					    subgraph.associatedFiles = [associated_file]
 | 
				
			||||||
 | 
					    # Creates the input and output tensor metadata to match self._model_file.
 | 
				
			||||||
 | 
					    dummy_tensor = _metadata_fb.TensorMetadataT()
 | 
				
			||||||
 | 
					    subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor]
 | 
				
			||||||
 | 
					    subgraph.outputTensorMetadata = [dummy_tensor]
 | 
				
			||||||
 | 
					    md_buffer = self._create_model_meta_with_subgraph_meta(subgraph)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Populate the metadata to a model.
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_buffer(md_buffer)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # The recorded file name in metadata should only contain file basename; file
 | 
				
			||||||
 | 
					    # directory should not be included.
 | 
				
			||||||
 | 
					    recorded_files = populator.get_recorded_associated_file_list()
 | 
				
			||||||
 | 
					    self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataDisplayerTest(MetadataTest):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super(MetadataDisplayerTest, self).setUp()
 | 
				
			||||||
 | 
					    self._model_with_meta_file = (
 | 
				
			||||||
 | 
					        self._create_model_with_metadata_and_associated_files())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_model_with_metadata_and_associated_files(self):
 | 
				
			||||||
 | 
					    model_buf = self._create_model_buf()
 | 
				
			||||||
 | 
					    model_file = self.create_tempfile().full_path
 | 
				
			||||||
 | 
					    with open(model_file, "wb") as f:
 | 
				
			||||||
 | 
					      f.write(model_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    populator = _metadata.MetadataPopulator.with_model_file(model_file)
 | 
				
			||||||
 | 
					    populator.load_metadata_file(self._metadata_file)
 | 
				
			||||||
 | 
					    populator.load_associated_files([self._file1, self._file2])
 | 
				
			||||||
 | 
					    populator.populate()
 | 
				
			||||||
 | 
					    return model_file
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self):
 | 
				
			||||||
 | 
					    model_buf = self._create_model_buffer_with_wrong_identifier()
 | 
				
			||||||
 | 
					    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
 | 
				
			||||||
 | 
					    model_buf = self._populate_metadata_with_identifier(
 | 
				
			||||||
 | 
					        model_buf, metadata_buf,
 | 
				
			||||||
 | 
					        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataDisplayer.with_model_buffer(model_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "The metadata buffer does not have the expected identifier, and may not"
 | 
				
			||||||
 | 
					        " be a valid TFLite Metadata.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self):
 | 
				
			||||||
 | 
					    model_buf = self._create_model_buffer_with_wrong_identifier()
 | 
				
			||||||
 | 
					    metadata_file = self._create_metadata_file()
 | 
				
			||||||
 | 
					    wrong_identifier = b"widn"
 | 
				
			||||||
 | 
					    metadata_buf = bytearray(_read_file(metadata_file))
 | 
				
			||||||
 | 
					    model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
 | 
				
			||||||
 | 
					                                                        wrong_identifier)
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataDisplayer.with_model_buffer(model_buf)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "The model provided does not have the expected identifier, and "
 | 
				
			||||||
 | 
					        "may not be a valid TFLite model.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelFileInvalidModelFileThrowsException(self):
 | 
				
			||||||
 | 
					    with self.assertRaises(IOError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataDisplayer.with_model_file(self._invalid_file)
 | 
				
			||||||
 | 
					    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
 | 
				
			||||||
 | 
					                     str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelFileModelWithoutMetadataThrowsException(self):
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataDisplayer.with_model_file(self._model_file)
 | 
				
			||||||
 | 
					    self.assertEqual("The model does not have metadata.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelFileModelWithMetadata(self):
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_file(
 | 
				
			||||||
 | 
					        self._model_with_meta_file)
 | 
				
			||||||
 | 
					    self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelBufferInvalidModelBufferThrowsException(self):
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1))
 | 
				
			||||||
 | 
					    self.assertEqual("model_buffer cannot be empty.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelBufferModelWithOutMetadataThrowsException(self):
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf())
 | 
				
			||||||
 | 
					    self.assertEqual("The model does not have metadata.", str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testLoadModelBufferModelWithMetadata(self):
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_buffer(
 | 
				
			||||||
 | 
					        _read_file(self._model_with_meta_file))
 | 
				
			||||||
 | 
					    self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetAssociatedFileBufferShouldSucceed(self):
 | 
				
			||||||
 | 
					    # _model_with_meta_file contains file1 and file2.
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_file(
 | 
				
			||||||
 | 
					        self._model_with_meta_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual_content = displayer.get_associated_file_buffer("file2")
 | 
				
			||||||
 | 
					    self.assertEqual(actual_content, self._file2_content)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetAssociatedFileBufferFailsWithNonExistentFile(self):
 | 
				
			||||||
 | 
					    # _model_with_meta_file contains file1 and file2.
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_file(
 | 
				
			||||||
 | 
					        self._model_with_meta_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    non_existent_file = "non_existent_file"
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError) as error:
 | 
				
			||||||
 | 
					      displayer.get_associated_file_buffer(non_existent_file)
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        "The file, {}, does not exist in the model.".format(non_existent_file),
 | 
				
			||||||
 | 
					        str(error.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetMetadataBufferShouldSucceed(self):
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_file(
 | 
				
			||||||
 | 
					        self._model_with_meta_file)
 | 
				
			||||||
 | 
					    actual_buffer = displayer.get_metadata_buffer()
 | 
				
			||||||
 | 
					    actual_json = _metadata.convert_to_json(actual_buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Verifies the generated json file.
 | 
				
			||||||
 | 
					    golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
 | 
				
			||||||
 | 
					    with open(golden_json_file_path, "r") as f:
 | 
				
			||||||
 | 
					      expected = f.read()
 | 
				
			||||||
 | 
					    self.assertEqual(actual_json, expected)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetMetadataJsonModelWithMetadata(self):
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_file(
 | 
				
			||||||
 | 
					        self._model_with_meta_file)
 | 
				
			||||||
 | 
					    actual = displayer.get_metadata_json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Verifies the generated json file.
 | 
				
			||||||
 | 
					    golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
 | 
				
			||||||
 | 
					    expected = _read_file(golden_json_file_path, "r")
 | 
				
			||||||
 | 
					    self.assertEqual(actual, expected)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testGetPackedAssociatedFileListModelWithMetadata(self):
 | 
				
			||||||
 | 
					    displayer = _metadata.MetadataDisplayer.with_model_file(
 | 
				
			||||||
 | 
					        self._model_with_meta_file)
 | 
				
			||||||
 | 
					    packed_files = displayer.get_packed_associated_file_list()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    expected_packed_files = [
 | 
				
			||||||
 | 
					        os.path.basename(self._file1),
 | 
				
			||||||
 | 
					        os.path.basename(self._file2)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    self.assertLen(
 | 
				
			||||||
 | 
					        packed_files, 2,
 | 
				
			||||||
 | 
					        "The following two associated files packed to the model: {0}; {1}"
 | 
				
			||||||
 | 
					        .format(expected_packed_files[0], expected_packed_files[1]))
 | 
				
			||||||
 | 
					    self.assertEqual(set(packed_files), set(expected_packed_files))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MetadataUtilTest(MetadataTest):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_convert_to_json_should_succeed(self):
 | 
				
			||||||
 | 
					    metadata_buf = _read_file(self._metadata_file_with_version)
 | 
				
			||||||
 | 
					    metadata_json = _metadata.convert_to_json(metadata_buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Verifies the generated json file.
 | 
				
			||||||
 | 
					    golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
 | 
				
			||||||
 | 
					    expected = _read_file(golden_json_file_path, "r")
 | 
				
			||||||
 | 
					    self.assertEqual(metadata_json, expected)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					  absltest.main()
 | 
				
			||||||
							
								
								
									
										45
									
								
								mediapipe/tasks/python/test/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								mediapipe/tasks/python/test/test_utils.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,45 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Test util for MediaPipe Tasks."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from absl import flags
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.python._framework_bindings import image as image_module
 | 
				
			||||||
 | 
					from mediapipe.python._framework_bindings import image_frame as image_frame_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					FLAGS = flags.FLAGS
 | 
				
			||||||
 | 
					_Image = image_module.Image
 | 
				
			||||||
 | 
					_ImageFormat = image_frame_module.ImageFormat
 | 
				
			||||||
 | 
					_RGB_CHANNELS = 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_srcdir():
 | 
				
			||||||
 | 
					  """Returns the path where to look for test data files."""
 | 
				
			||||||
 | 
					  if "test_srcdir" in flags.FLAGS:
 | 
				
			||||||
 | 
					    return flags.FLAGS["test_srcdir"].value
 | 
				
			||||||
 | 
					  elif "TEST_SRCDIR" in os.environ:
 | 
				
			||||||
 | 
					    return os.environ["TEST_SRCDIR"]
 | 
				
			||||||
 | 
					  else:
 | 
				
			||||||
 | 
					    raise RuntimeError("Missing TEST_SRCDIR environment.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_test_data_path(file_or_dirname: str) -> str:
 | 
				
			||||||
 | 
					  """Returns full test data path."""
 | 
				
			||||||
 | 
					  for (directory, subdirs, files) in os.walk(test_srcdir()):
 | 
				
			||||||
 | 
					    for f in subdirs + files:
 | 
				
			||||||
 | 
					      if f.endswith(file_or_dirname):
 | 
				
			||||||
 | 
					        return os.path.join(directory, f)
 | 
				
			||||||
 | 
					  raise ValueError("No %s in test directory" % file_or_dirname)
 | 
				
			||||||
| 
						 | 
					@ -31,7 +31,7 @@ py_test(
 | 
				
			||||||
        "//mediapipe/tasks/python/components/containers:category",
 | 
					        "//mediapipe/tasks/python/components/containers:category",
 | 
				
			||||||
        "//mediapipe/tasks/python/components/containers:detections",
 | 
					        "//mediapipe/tasks/python/components/containers:detections",
 | 
				
			||||||
        "//mediapipe/tasks/python/core:base_options",
 | 
					        "//mediapipe/tasks/python/core:base_options",
 | 
				
			||||||
        "//mediapipe/tasks/python/test:test_util",
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
        "//mediapipe/tasks/python/vision:object_detector",
 | 
					        "//mediapipe/tasks/python/vision:object_detector",
 | 
				
			||||||
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
 | 
					        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,7 +25,7 @@ from mediapipe.tasks.python.components.containers import bounding_box as boundin
 | 
				
			||||||
from mediapipe.tasks.python.components.containers import category as category_module
 | 
					from mediapipe.tasks.python.components.containers import category as category_module
 | 
				
			||||||
from mediapipe.tasks.python.components.containers import detections as detections_module
 | 
					from mediapipe.tasks.python.components.containers import detections as detections_module
 | 
				
			||||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
					from mediapipe.tasks.python.core import base_options as base_options_module
 | 
				
			||||||
from mediapipe.tasks.python.test import test_util
 | 
					from mediapipe.tasks.python.test import test_utils
 | 
				
			||||||
from mediapipe.tasks.python.vision import object_detector
 | 
					from mediapipe.tasks.python.vision import object_detector
 | 
				
			||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
 | 
					from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -99,8 +99,8 @@ class ObjectDetectorTest(parameterized.TestCase):
 | 
				
			||||||
  def setUp(self):
 | 
					  def setUp(self):
 | 
				
			||||||
    super().setUp()
 | 
					    super().setUp()
 | 
				
			||||||
    self.test_image = _Image.create_from_file(
 | 
					    self.test_image = _Image.create_from_file(
 | 
				
			||||||
        test_util.get_test_data_path(_IMAGE_FILE))
 | 
					        test_utils.get_test_data_path(_IMAGE_FILE))
 | 
				
			||||||
    self.model_path = test_util.get_test_data_path(_MODEL_FILE)
 | 
					    self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_from_file_succeeds_with_valid_model_path(self):
 | 
					  def test_create_from_file_succeeds_with_valid_model_path(self):
 | 
				
			||||||
    # Creates with default option and valid model file successfully.
 | 
					    # Creates with default option and valid model file successfully.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										12
									
								
								mediapipe/tasks/testdata/metadata/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								mediapipe/tasks/testdata/metadata/BUILD
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -28,9 +28,13 @@ mediapipe_files(srcs = [
 | 
				
			||||||
    "mobile_ica_8bit-without-model-metadata.tflite",
 | 
					    "mobile_ica_8bit-without-model-metadata.tflite",
 | 
				
			||||||
    "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
 | 
					    "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
 | 
				
			||||||
    "mobilenet_v1_0.25_224_1_default_1.tflite",
 | 
					    "mobilenet_v1_0.25_224_1_default_1.tflite",
 | 
				
			||||||
 | 
					    "mobilenet_v2_1.0_224_quant.tflite",
 | 
				
			||||||
])
 | 
					])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
exports_files(["external_file"])
 | 
					exports_files([
 | 
				
			||||||
 | 
					    "external_file",
 | 
				
			||||||
 | 
					    "golden_json.json",
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
filegroup(
 | 
					filegroup(
 | 
				
			||||||
    name = "model_files",
 | 
					    name = "model_files",
 | 
				
			||||||
| 
						 | 
					@ -40,10 +44,14 @@ filegroup(
 | 
				
			||||||
        "mobile_ica_8bit-without-model-metadata.tflite",
 | 
					        "mobile_ica_8bit-without-model-metadata.tflite",
 | 
				
			||||||
        "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
 | 
					        "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
 | 
				
			||||||
        "mobilenet_v1_0.25_224_1_default_1.tflite",
 | 
					        "mobilenet_v1_0.25_224_1_default_1.tflite",
 | 
				
			||||||
 | 
					        "mobilenet_v2_1.0_224_quant.tflite",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
filegroup(
 | 
					filegroup(
 | 
				
			||||||
    name = "data_files",
 | 
					    name = "data_files",
 | 
				
			||||||
    srcs = ["external_file"],
 | 
					    srcs = [
 | 
				
			||||||
 | 
					        "external_file",
 | 
				
			||||||
 | 
					        "golden_json.json",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										28
									
								
								mediapipe/tasks/testdata/metadata/golden_json.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								mediapipe/tasks/testdata/metadata/golden_json.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,28 @@
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "name": "Mobilenet_quantized",
 | 
				
			||||||
 | 
					  "subgraph_metadata": [
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      "input_tensor_metadata": [
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      ],
 | 
				
			||||||
 | 
					      "output_tensor_metadata": [
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          "associated_files": [
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					              "name": "file2"
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          ]
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      ]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  ],
 | 
				
			||||||
 | 
					  "associated_files": [
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      "name": "file1"
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  ],
 | 
				
			||||||
 | 
					  "min_parser_version": "1.0.0"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										12
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -166,6 +166,12 @@ def external_files():
 | 
				
			||||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"],
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http_file(
 | 
				
			||||||
 | 
					        name = "com_google_mediapipe_golden_json_json",
 | 
				
			||||||
 | 
					        sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c",
 | 
				
			||||||
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/golden_json.json?generation=1664340169675228"],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
        name = "com_google_mediapipe_hair_segmentation_tflite",
 | 
					        name = "com_google_mediapipe_hair_segmentation_tflite",
 | 
				
			||||||
        sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633",
 | 
					        sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633",
 | 
				
			||||||
| 
						 | 
					@ -316,6 +322,12 @@ def external_files():
 | 
				
			||||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http_file(
 | 
				
			||||||
 | 
					        name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite",
 | 
				
			||||||
 | 
					        sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d",
 | 
				
			||||||
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
        name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
 | 
					        name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite",
 | 
				
			||||||
        sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",
 | 
					        sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user