Internal change

PiperOrigin-RevId: 514001732
This commit is contained in:
Yuqi Li 2023-03-03 22:12:38 -08:00 committed by Copybara-Service
parent c98b4b6ec6
commit dbe4175a08
19 changed files with 435 additions and 24 deletions

View File

@ -14,9 +14,13 @@ stamp_metadata_parser_version(
cc_library(
name = "metadata_extractor",
srcs = ["metadata_extractor.cc"],
hdrs = ["metadata_extractor.h"],
hdrs = [
"metadata_extractor.h",
"metadata_parser_h",
],
visibility = ["//visibility:public"],
deps = [
":metadata_version_utils",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
@ -68,3 +72,10 @@ cc_library(
"@zlib//:zlib_minizip",
],
)
cc_library(
name = "metadata_version_utils",
srcs = ["metadata_version_utils.cc"],
hdrs = ["metadata_version_utils.h"],
deps = ["@com_google_absl//absl/strings"],
)

View File

@ -26,6 +26,8 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/metadata_parser.h"
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -164,6 +166,18 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
return CreateStatusWithPayload(StatusCode::kInternal,
"Expected Model Metadata not to be null.");
}
auto min_parser_version = model_metadata_->min_parser_version();
if (min_parser_version != nullptr &&
CompareVersions(min_parser_version->c_str(), kMetadataParserVersion) >
0) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Metadata schema version %s is smaller than the minimum version "
"%s to parse the metadata flatbuffer.",
kMetadataParserVersion, min_parser_version->c_str()),
MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError);
}
return ExtractAssociatedFiles(buffer_data, buffer_size);
break;
}
@ -299,6 +313,29 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
return output_process_units == nullptr ? 0 : output_process_units->size();
}
const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>*
ModelMetadataExtractor::GetCustomMetadataList() const {
if (model_metadata_ == nullptr ||
model_metadata_->subgraph_metadata() == nullptr) {
return nullptr;
}
return model_metadata_->subgraph_metadata()
->Get(kDefaultSubgraphIndex)
->custom_metadata();
}
const tflite::CustomMetadata* ModelMetadataExtractor::GetCustomMetadata(
int index) const {
return GetItemFromVector<tflite::CustomMetadata>(GetCustomMetadataList(),
index);
}
int ModelMetadataExtractor::GetCustomMetadataCount() const {
const Vector<flatbuffers::Offset<tflite::CustomMetadata>>* custom_medata_vec =
GetCustomMetadataList();
return custom_medata_vec == nullptr ? 0 : custom_medata_vec->size();
}
} // namespace metadata
} // namespace tasks
} // namespace mediapipe

View File

@ -136,6 +136,19 @@ class ModelMetadataExtractor {
// there is no output process units.
int GetOutputProcessUnitsCount() const;
// Gets a list of custom metadata from SubgraphMetadata.custom_metadata,
// could be nullptr.
const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>*
GetCustomMetadataList() const;
// Gets the custom metadata specified by the given index, or nullptr in case
// there is no custom metadata or the index is out of range.
const tflite::CustomMetadata* GetCustomMetadata(int index) const;
// Gets the count of custom metadata. In particular, 0 is returned when
// there is no custom metadata.
int GetCustomMetadataCount() const;
private:
static constexpr int kDefaultSubgraphIndex = 0;
// Private default constructor, called from CreateFromModel().

View File

@ -21,7 +21,7 @@ namespace metadata {
// The version of the metadata parser that this metadata versioning library is
// depending on.
inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
inline constexpr char kMetadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
} // namespace metadata
} // namespace tasks

View File

@ -57,6 +57,7 @@ enum class SchemaMembers {
kContentPropertiesAudioProperties = 8,
kAssociatedFileTypeScannIndexFile = 9,
kAssociatedFileVersion = 10,
kCustomMetadata = 11,
};
// Helper class to compare semantic versions in terms of three integers, major,
@ -125,6 +126,8 @@ Version GetMemberVersion(SchemaMembers member) {
return Version(1, 4, 0);
case SchemaMembers::kAssociatedFileVersion:
return Version(1, 4, 1);
case SchemaMembers::kCustomMetadata:
return Version(1, 5, 0);
default:
// Should never happen.
TFLITE_LOG(FATAL) << "Unsupported schema member: "
@ -281,6 +284,12 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
min_version);
}
// Checks for the options field.
if (table->custom_metadata() != nullptr) {
UpdateMinimumVersion(GetMemberVersion(SchemaMembers::kCustomMetadata),
min_version);
}
}
template <>

View File

@ -0,0 +1,48 @@
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
#include <string>
#include "absl/strings/str_split.h"
namespace mediapipe {
namespace tasks {
namespace metadata {
namespace {
static int32_t GetValueOrZero(const std::vector<std::string> &list,
const int index) {
int32_t value = 0;
if (index <= list.size() - 1) {
value = std::stoi(list[index]);
}
return value;
}
} // namespace
int CompareVersions(absl::string_view version_a, absl::string_view version_b) {
std::vector<std::string> version_a_components =
absl::StrSplit(version_a, '.', absl::SkipEmpty());
std::vector<std::string> version_b_components =
absl::StrSplit(version_b, '.', absl::SkipEmpty());
const int a_length = version_a_components.size();
const int b_length = version_b_components.size();
const int max_length = std::max(a_length, b_length);
for (int i = 0; i < max_length; ++i) {
const int a_val = GetValueOrZero(version_a_components, i);
const int b_val = GetValueOrZero(version_b_components, i);
if (a_val > b_val) {
return 1;
}
if (a_val < b_val) {
return -1;
}
}
return 0;
}
} // namespace metadata
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,21 @@
#ifndef MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
#define MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
#include "absl/strings/string_view.h"
namespace mediapipe {
namespace tasks {
namespace metadata {
// Compares two versions. The version format is "**.**.**" such as "1.12.3".
// If version_a is newer than version_b, return 1; if version_a is
// older than version_b, return -1; if version_a equals to version_b,
// returns 0. For example, if version_a = 1.12.3 and version_b = 1.12.1,
// version_a is newer than version_b, and the function return is 1.
int CompareVersions(absl::string_view version_a, absl::string_view version_b);
} // namespace metadata
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_

View File

@ -22,6 +22,7 @@ cc_test(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
],
)
@ -44,3 +45,12 @@ cc_test(
"//mediapipe/tasks/cc/metadata:metadata_version",
],
)
cc_test(
name = "metadata_version_utils_test",
srcs = ["metadata_version_utils_test.cc"],
deps = [
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/metadata:metadata_version_utils",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -26,6 +27,7 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_parser.h"
namespace mediapipe {
namespace tasks {
@ -43,6 +45,11 @@ constexpr char kMobileIcaWithoutTfLiteMetadata[] =
"mobile_ica_8bit-without-model-metadata.tflite";
constexpr char kMobileIcaWithTfLiteMetadata[] =
"mobile_ica_8bit-with-metadata.tflite";
constexpr char kMobileIcaWithCustomMetadata[] =
"mobile_ica_8bit-with-custom-metadata.tflite";
// `min_parser_version=1000.0.0` for test purpose.
constexpr char kMobileIcaWithLargeMinParseVersion[] =
"mobile_ica_8bit-with-large-min-parser-version.tflite";
constexpr char kMobileIcaWithUnsupportedMetadataVersion[] =
"mobile_ica_8bit-with-unsupported-metadata-version.tflite";
@ -334,6 +341,81 @@ TEST(ModelMetadataExtractorTest, GetModelVersionWorks) {
MP_EXPECT_OK(extractor->GetModelVersion().status());
}
TEST(ModelMetadataExtractorTest, GetCustomMetadataListWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
EXPECT_TRUE(extractor->GetCustomMetadataList() != nullptr);
}
TEST(ModelMetadataExtractorTest,
GetCustomMetadataListWithoutTfLiteMetadataWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
EXPECT_TRUE(extractor->GetCustomMetadataList() == nullptr);
}
TEST(ModelMetadataExtractorTest, GetCustomMetadataWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
EXPECT_TRUE(extractor->GetCustomMetadata(0) != nullptr);
}
TEST(ModelMetadataExtractorTest, GetCustomMetadataWithoutTfLiteMetadataWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
EXPECT_TRUE(extractor->GetCustomMetadata(0) == nullptr);
}
TEST(ModelMetadataExtractorTest, GetCustomMetadataWithOutOfRangeIndexWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
EXPECT_TRUE(extractor->GetCustomMetadata(-1) == nullptr);
EXPECT_TRUE(extractor->GetCustomMetadata(2) == nullptr);
}
TEST(ModelMetadataExtractorTest, GetCustomMetadataCountWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
EXPECT_EQ(extractor->GetCustomMetadataCount(), 2);
}
TEST(ModelMetadataExtractorTest,
GetCustomMetadataCountWithoutTfLiteMetadataWorks) {
std::string buffer;
MP_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<ModelMetadataExtractor> extractor,
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
EXPECT_EQ(extractor->GetCustomMetadataCount(), 0);
}
TEST(ModelMetadataExtractorTest,
CreateFailsWithIncompatibleWithMinParserVersion) {
std::string buffer;
auto extractor =
CreateMetadataExtractor(kMobileIcaWithLargeMinParseVersion, &buffer);
EXPECT_THAT(extractor.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(
extractor.status().message(),
absl::StrFormat("Metadata schema version %s is smaller than the minimum "
"version 1000.0.0 to parse the metadata flatbuffer.",
kMetadataParserVersion));
EXPECT_THAT(extractor.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError))));
}
} // namespace
} // namespace metadata
} // namespace tasks

View File

@ -27,9 +27,9 @@ using ::testing::MatchesRegex;
TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) {
// Validates that the version is well-formed (x.y.z).
#ifdef _WIN32
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+"));
EXPECT_THAT(kMetadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+"));
#else
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
EXPECT_THAT(kMetadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
#endif // _WIN32
}

View File

@ -37,6 +37,8 @@ using ::tflite::AudioPropertiesBuilder;
using ::tflite::BertTokenizerOptionsBuilder;
using ::tflite::ContentBuilder;
using ::tflite::ContentProperties_AudioProperties;
using ::tflite::CustomMetadata;
using ::tflite::CustomMetadataBuilder;
using ::tflite::ModelMetadataBuilder;
using ::tflite::NormalizationOptionsBuilder;
using ::tflite::ProcessUnit;
@ -483,6 +485,34 @@ TEST(MetadataVersionTest,
EXPECT_THAT(min_version, StrEq("1.4.1"));
}
TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
// Creates a metadata flatbuffer with the field custom_metadata in subgraph
// metadata.
FlatBufferBuilder builder(1024);
auto name = builder.CreateString("custom_metadata");
auto data = builder.CreateVector(std::vector<unsigned char>{'a'});
CustomMetadataBuilder custom_metadata_builder(builder);
custom_metadata_builder.add_name(name);
custom_metadata_builder.add_data(data);
auto custom_metadata = builder.CreateVector(
std::vector<Offset<CustomMetadata>>{custom_metadata_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_custom_metadata(custom_metadata);
auto subgraphs = builder.CreateVector(
std::vector<Offset<SubGraphMetadata>>{subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.5.0.
EXPECT_EQ(min_version, "1.5.0");
}
} // namespace
} // namespace metadata
} // namespace tasks

View File

@ -0,0 +1,43 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe {
namespace tasks {
namespace metadata {
namespace {
TEST(MetadataVersionTest, CompareVersions) {
ASSERT_EQ(0, CompareVersions("1.0", "1.0"));
ASSERT_EQ(0, CompareVersions("1", "1.0.0"));
ASSERT_EQ(-1, CompareVersions("2", "3"));
ASSERT_EQ(-1, CompareVersions("1.2", "1.3"));
ASSERT_EQ(-1, CompareVersions("3.2.9", "3.2.10"));
ASSERT_EQ(-1, CompareVersions("10.1.9", "10.2"));
ASSERT_EQ(1, CompareVersions("3", "2"));
ASSERT_EQ(1, CompareVersions("1.3", "1.2"));
ASSERT_EQ(1, CompareVersions("0.95", "0.94.3124"));
ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.1.9"));
ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.0.13"));
}
} // namespace
} // namespace metadata
} // namespace tasks
} // namespace mediapipe

View File

@ -49,8 +49,7 @@ namespace tflite;
// New fields and types will have associated comments with the schema version
// for which they were added.
//
// TODO: Add LINT change check as needed.
// Schema Semantic version: 1.4.1
// Schema Semantic version: 1.5.0
// This indicates the flatbuffer compatibility. The number will bump up when a
// break change is applied to the schema, such as removing fields or adding new
@ -69,11 +68,11 @@ file_identifier "M001";
// 1.3.0 - Added AudioProperties to ContentProperties.
// 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType.
// 1.4.1 - Added version to AssociatedFile.
// 1.5.0 - Added CustomMetadata in SubGraphMetadata.
// File extension of any written files.
file_extension "tflitemeta";
// TODO: Add LINT change check as needed.
enum AssociatedFileType : byte {
UNKNOWN = 0,
@ -609,6 +608,11 @@ table TensorMetadata {
associated_files:[AssociatedFile];
}
table CustomMetadata {
name:string;
data:[ubyte] (force_align: 16);
}
table SubGraphMetadata {
// Name of the subgraph.
//
@ -676,6 +680,7 @@ table SubGraphMetadata {
// Added in: 1.2.0
output_tensor_groups:[TensorGroup];
custom_metadata:[CustomMetadata];
}
table ModelMetadata {
@ -721,5 +726,6 @@ table ModelMetadata {
// the metadata is populated into a TFLite model.
min_parser_version:string;
}
// metadata_version.cc)
root_type ModelMetadata;

View File

@ -14,6 +14,7 @@
# ==============================================================================
"""Helper classes for common model metadata information."""
import abc
import collections
import csv
import os
@ -377,11 +378,12 @@ class TensorMd:
tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
used to locate the corresponding tensor and decide the order of the tensor
metadata [2] when populating model metadata.
content_range_md: information of content range [3]. [1]:
content_range_md: information of content range [3].
[1]:
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[2]:
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
[3]:
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
"""
@ -777,17 +779,18 @@ class ClassificationTensorMd(TensorMd):
order of the tensor metadata [4] when populating model metadata.
score_thresholding_md: information of the score thresholding [5] in the
classification tensor.
content_range_md: information of content range [6]. [1]:
content_range_md: information of content range [6].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
[3]:
[3]:
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[4]:
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
[5]:
[5]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
[6]:
[6]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
"""
self.score_calibration_md = score_calibration_md
@ -890,9 +893,10 @@ class CategoryTensorMd(TensorMd):
name: name of the tensor.
description: description of what the tensor is.
label_files: information of the label files [1] in the category tensor.
content_range_md: information of content range [2]. [1]:
content_range_md: information of content range [2].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116
[2]:
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
"""
# In category tensors, label files are in the type of TENSOR_VALUE_LABELS.
@ -934,9 +938,10 @@ class DetectionOutputTensorsMd:
label_files: information of the label files [1] in the classification
tensor.
score_calibration_md: information of the score calibration files operation
[2] in the classification tensor. [1]:
[2] in the classification tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
"""
content_range_md = ValueRangeMd(
@ -1010,7 +1015,8 @@ class TensorGroupMd:
Args:
name: name of tensor group.
tensor_names: Names of the tensors to group together, corresponding to
TensorMetadata.name [1]. [1]:
TensorMetadata.name [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564
"""
self.name = name
@ -1022,3 +1028,14 @@ class TensorGroupMd:
group.name = self.name
group.tensorNames = self.tensor_names
return group
class CustomMetadataMd(abc.ABC):
"""An abstract class of a container for the custom metadata information."""
def __init__(self, name: Optional[str] = None):
self.name = name
@abc.abstractmethod
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
"""Creates the custom metadata based on the information."""

View File

@ -317,6 +317,7 @@ def _create_metadata_buffer(
output_md: Optional[List[metadata_info.TensorMd]] = None,
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None,
output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None,
custom_metadata_md: Optional[List[metadata_info.CustomMetadataMd]] = None,
) -> bytearray:
"""Creates a buffer of the metadata.
@ -326,9 +327,11 @@ def _create_metadata_buffer(
input_md: metadata information of the input tensors.
output_md: metadata information of the output tensors.
input_process_units: a lists of metadata of the input process units [1].
output_group_md: a list of metadata of output tensor groups [2]; [1]:
output_group_md: a list of metadata of output tensor groups [2];
custom_metadata_md: a lists of custom metadata.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
[2]:
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677
Returns:
@ -367,6 +370,10 @@ def _create_metadata_buffer(
subgraph_metadata.outputTensorMetadata = output_metadata
if input_process_units:
subgraph_metadata.inputProcessUnits = input_process_units
if custom_metadata_md:
subgraph_metadata.customMetadata = [
m.create_metadata() for m in custom_metadata_md
]
if output_group_md:
subgraph_metadata.outputTensorGroups = [
m.create_metadata() for m in output_group_md
@ -416,6 +423,7 @@ class MetadataWriter(object):
self._output_mds = []
self._output_group_mds = []
self._associated_files = []
self._custom_metadata_mds = []
self._temp_folder = tempfile.TemporaryDirectory()
def __del__(self):
@ -657,6 +665,12 @@ class MetadataWriter(object):
self._output_mds.append(output_md)
return self
def add_custom_metadata(
self, custom_metadata_md: metadata_info.CustomMetadataMd
) -> 'MetadataWriter':
self._custom_metadata_mds.append(custom_metadata_md)
return self
def populate(self) -> Tuple[bytearray, str]:
"""Populates metadata into the TFLite file.
@ -674,6 +688,7 @@ class MetadataWriter(object):
input_md=self._input_mds,
output_md=self._output_mds,
input_process_units=self._input_process_units,
custom_metadata_md=self._custom_metadata_mds,
output_group_md=self._output_group_mds,
)
populator.load_metadata_buffer(metadata_buffer)

View File

@ -33,6 +33,8 @@ py_test(
"//mediapipe/tasks/testdata/metadata:model_files",
],
deps = [
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/test:test_utils",
],

View File

@ -18,6 +18,8 @@ import tempfile
from absl.testing import absltest
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.test import test_utils
@ -29,6 +31,16 @@ _SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
class TestCustomMetadataMd(metadata_info.CustomMetadataMd):
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
"""Creates the custom metadata based on the information."""
custom_metadata_field = _metadata_fb.CustomMetadataT()
custom_metadata_field.name = self.name
custom_metadata_field.data = b'\x01\x02'
return custom_metadata_field
class LabelsTest(absltest.TestCase):
def test_category_name(self):
@ -415,7 +427,6 @@ class MetadataWriterForTaskTest(absltest.TestCase):
score_thresholding=metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
_, metadata_json = writer.populate()
print(metadata_json)
self.assertJsonEqual(
metadata_json, """{
"subgraph_metadata": [
@ -465,6 +476,46 @@ class MetadataWriterForTaskTest(absltest.TestCase):
}
""")
def test_add_custom_metadata(self):
writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer
)
writer.add_custom_metadata(
TestCustomMetadataMd(name='test_custom_metadata')
)
_, metadata_json = writer.populate()
self.assertJsonEqual(
metadata_json,
"""
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input"
}
],
"output_tensor_metadata": [
{
"name": "MobilenetV1/Predictions/Reshape_1"
}
],
"custom_metadata": [
{
"name": "test_custom_metadata",
"data": [
1,
2
]
}
]
}
],
"min_parser_version": "1.5.0"
}
""",
)
if __name__ == '__main__':
absltest.main()

View File

@ -31,6 +31,8 @@ mediapipe_files(srcs = [
"efficientdet_lite0_v1.json",
"efficientdet_lite0_v1.tflite",
"labelmap.txt",
"mobile_ica_8bit-with-custom-metadata.tflite",
"mobile_ica_8bit-with-large-min-parser-version.tflite",
"mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite",
@ -86,6 +88,8 @@ filegroup(
"bert_text_classifier_no_metadata.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
"efficientdet_lite0_v1.tflite",
"mobile_ica_8bit-with-custom-metadata.tflite",
"mobile_ica_8bit-with-large-min-parser-version.tflite",
"mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite",

View File

@ -544,6 +544,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"],
)
http_file(
name = "com_google_mediapipe_mobile_ica_8bit-with-custom-metadata_tflite",
sha256 = "31f34f0dd0dc39e69e9c3deb1e3f3278febeb82ecf57c235834348a75df8fb51",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-custom-metadata.tflite?generation=1677906531317767"],
)
http_file(
name = "com_google_mediapipe_mobile_ica_8bit-with-large-min-parser-version_tflite",
sha256 = "53d0ea047682539964820fcfc5dc81f4928957470f453f2065f4c2ab87406803",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-large-min-parser-version.tflite?generation=1677906534624784"],
)
http_file(
name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite",
sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5",