Internal change
PiperOrigin-RevId: 514001732
This commit is contained in:
parent
c98b4b6ec6
commit
dbe4175a08
|
@ -14,9 +14,13 @@ stamp_metadata_parser_version(
|
|||
cc_library(
|
||||
name = "metadata_extractor",
|
||||
srcs = ["metadata_extractor.cc"],
|
||||
hdrs = ["metadata_extractor.h"],
|
||||
hdrs = [
|
||||
"metadata_extractor.h",
|
||||
"metadata_parser_h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":metadata_version_utils",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||
|
@ -68,3 +72,10 @@ cc_library(
|
|||
"@zlib//:zlib_minizip",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "metadata_version_utils",
|
||||
srcs = ["metadata_version_utils.cc"],
|
||||
hdrs = ["metadata_version_utils.h"],
|
||||
deps = ["@com_google_absl//absl/strings"],
|
||||
)
|
||||
|
|
|
@ -26,6 +26,8 @@ limitations under the License.
|
|||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_parser.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
@ -164,6 +166,18 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
|
|||
return CreateStatusWithPayload(StatusCode::kInternal,
|
||||
"Expected Model Metadata not to be null.");
|
||||
}
|
||||
auto min_parser_version = model_metadata_->min_parser_version();
|
||||
if (min_parser_version != nullptr &&
|
||||
CompareVersions(min_parser_version->c_str(), kMetadataParserVersion) >
|
||||
0) {
|
||||
return CreateStatusWithPayload(
|
||||
StatusCode::kInvalidArgument,
|
||||
absl::StrFormat(
|
||||
"Metadata schema version %s is smaller than the minimum version "
|
||||
"%s to parse the metadata flatbuffer.",
|
||||
kMetadataParserVersion, min_parser_version->c_str()),
|
||||
MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError);
|
||||
}
|
||||
return ExtractAssociatedFiles(buffer_data, buffer_size);
|
||||
break;
|
||||
}
|
||||
|
@ -299,6 +313,29 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
|
|||
return output_process_units == nullptr ? 0 : output_process_units->size();
|
||||
}
|
||||
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>*
|
||||
ModelMetadataExtractor::GetCustomMetadataList() const {
|
||||
if (model_metadata_ == nullptr ||
|
||||
model_metadata_->subgraph_metadata() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return model_metadata_->subgraph_metadata()
|
||||
->Get(kDefaultSubgraphIndex)
|
||||
->custom_metadata();
|
||||
}
|
||||
|
||||
const tflite::CustomMetadata* ModelMetadataExtractor::GetCustomMetadata(
|
||||
int index) const {
|
||||
return GetItemFromVector<tflite::CustomMetadata>(GetCustomMetadataList(),
|
||||
index);
|
||||
}
|
||||
|
||||
int ModelMetadataExtractor::GetCustomMetadataCount() const {
|
||||
const Vector<flatbuffers::Offset<tflite::CustomMetadata>>* custom_medata_vec =
|
||||
GetCustomMetadataList();
|
||||
return custom_medata_vec == nullptr ? 0 : custom_medata_vec->size();
|
||||
}
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -136,6 +136,19 @@ class ModelMetadataExtractor {
|
|||
// there is no output process units.
|
||||
int GetOutputProcessUnitsCount() const;
|
||||
|
||||
// Gets a list of custom metadata from SubgraphMetadata.custom_metadata,
|
||||
// could be nullptr.
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>*
|
||||
GetCustomMetadataList() const;
|
||||
|
||||
// Gets the custom metadata specified by the given index, or nullptr in case
|
||||
// there is no custom metadata or the index is out of range.
|
||||
const tflite::CustomMetadata* GetCustomMetadata(int index) const;
|
||||
|
||||
// Gets the count of custom metadata. In particular, 0 is returned when
|
||||
// there is no custom metadata.
|
||||
int GetCustomMetadataCount() const;
|
||||
|
||||
private:
|
||||
static constexpr int kDefaultSubgraphIndex = 0;
|
||||
// Private default constructor, called from CreateFromModel().
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace metadata {
|
|||
|
||||
// The version of the metadata parser that this metadata versioning library is
|
||||
// depending on.
|
||||
inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
|
||||
inline constexpr char kMetadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
|
|
|
@ -57,6 +57,7 @@ enum class SchemaMembers {
|
|||
kContentPropertiesAudioProperties = 8,
|
||||
kAssociatedFileTypeScannIndexFile = 9,
|
||||
kAssociatedFileVersion = 10,
|
||||
kCustomMetadata = 11,
|
||||
};
|
||||
|
||||
// Helper class to compare semantic versions in terms of three integers, major,
|
||||
|
@ -125,6 +126,8 @@ Version GetMemberVersion(SchemaMembers member) {
|
|||
return Version(1, 4, 0);
|
||||
case SchemaMembers::kAssociatedFileVersion:
|
||||
return Version(1, 4, 1);
|
||||
case SchemaMembers::kCustomMetadata:
|
||||
return Version(1, 5, 0);
|
||||
default:
|
||||
// Should never happen.
|
||||
TFLITE_LOG(FATAL) << "Unsupported schema member: "
|
||||
|
@ -281,6 +284,12 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
|
|||
GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
|
||||
min_version);
|
||||
}
|
||||
|
||||
// Checks for the options field.
|
||||
if (table->custom_metadata() != nullptr) {
|
||||
UpdateMinimumVersion(GetMemberVersion(SchemaMembers::kCustomMetadata),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
48
mediapipe/tasks/cc/metadata/metadata_version_utils.cc
Normal file
48
mediapipe/tasks/cc/metadata/metadata_version_utils.cc
Normal file
|
@ -0,0 +1,48 @@
|
|||
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
static int32_t GetValueOrZero(const std::vector<std::string> &list,
|
||||
const int index) {
|
||||
int32_t value = 0;
|
||||
if (index <= list.size() - 1) {
|
||||
value = std::stoi(list[index]);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int CompareVersions(absl::string_view version_a, absl::string_view version_b) {
|
||||
std::vector<std::string> version_a_components =
|
||||
absl::StrSplit(version_a, '.', absl::SkipEmpty());
|
||||
std::vector<std::string> version_b_components =
|
||||
absl::StrSplit(version_b, '.', absl::SkipEmpty());
|
||||
|
||||
const int a_length = version_a_components.size();
|
||||
const int b_length = version_b_components.size();
|
||||
const int max_length = std::max(a_length, b_length);
|
||||
|
||||
for (int i = 0; i < max_length; ++i) {
|
||||
const int a_val = GetValueOrZero(version_a_components, i);
|
||||
const int b_val = GetValueOrZero(version_b_components, i);
|
||||
if (a_val > b_val) {
|
||||
return 1;
|
||||
}
|
||||
if (a_val < b_val) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
21
mediapipe/tasks/cc/metadata/metadata_version_utils.h
Normal file
21
mediapipe/tasks/cc/metadata/metadata_version_utils.h
Normal file
|
@ -0,0 +1,21 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
|
||||
#define MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace metadata {
|
||||
|
||||
// Compares two versions. The version format is "**.**.**" such as "1.12.3".
|
||||
// If version_a is newer than version_b, return 1; if version_a is
|
||||
// older than version_b, return -1; if version_a equals to version_b,
|
||||
// returns 0. For example, if version_a = 1.12.3 and version_b = 1.12.1,
|
||||
// version_a is newer than version_b, and the function return is 1.
|
||||
int CompareVersions(absl::string_view version_a, absl::string_view version_b);
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
|
|
@ -22,6 +22,7 @@ cc_test(
|
|||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -44,3 +45,12 @@ cc_test(
|
|||
"//mediapipe/tasks/cc/metadata:metadata_version",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "metadata_version_utils_test",
|
||||
srcs = ["metadata_version_utils_test.cc"],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_version_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/cord.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -26,6 +27,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_parser.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -43,6 +45,11 @@ constexpr char kMobileIcaWithoutTfLiteMetadata[] =
|
|||
"mobile_ica_8bit-without-model-metadata.tflite";
|
||||
constexpr char kMobileIcaWithTfLiteMetadata[] =
|
||||
"mobile_ica_8bit-with-metadata.tflite";
|
||||
constexpr char kMobileIcaWithCustomMetadata[] =
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite";
|
||||
// `min_parser_version=1000.0.0` for test purpose.
|
||||
constexpr char kMobileIcaWithLargeMinParseVersion[] =
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite";
|
||||
|
||||
constexpr char kMobileIcaWithUnsupportedMetadataVersion[] =
|
||||
"mobile_ica_8bit-with-unsupported-metadata-version.tflite";
|
||||
|
@ -334,6 +341,81 @@ TEST(ModelMetadataExtractorTest, GetModelVersionWorks) {
|
|||
MP_EXPECT_OK(extractor->GetModelVersion().status());
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest, GetCustomMetadataListWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
|
||||
EXPECT_TRUE(extractor->GetCustomMetadataList() != nullptr);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest,
|
||||
GetCustomMetadataListWithoutTfLiteMetadataWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||
EXPECT_TRUE(extractor->GetCustomMetadataList() == nullptr);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest, GetCustomMetadataWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
|
||||
EXPECT_TRUE(extractor->GetCustomMetadata(0) != nullptr);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest, GetCustomMetadataWithoutTfLiteMetadataWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||
EXPECT_TRUE(extractor->GetCustomMetadata(0) == nullptr);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest, GetCustomMetadataWithOutOfRangeIndexWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||
EXPECT_TRUE(extractor->GetCustomMetadata(-1) == nullptr);
|
||||
EXPECT_TRUE(extractor->GetCustomMetadata(2) == nullptr);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest, GetCustomMetadataCountWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
|
||||
EXPECT_EQ(extractor->GetCustomMetadataCount(), 2);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest,
|
||||
GetCustomMetadataCountWithoutTfLiteMetadataWorks) {
|
||||
std::string buffer;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||
EXPECT_EQ(extractor->GetCustomMetadataCount(), 0);
|
||||
}
|
||||
|
||||
TEST(ModelMetadataExtractorTest,
|
||||
CreateFailsWithIncompatibleWithMinParserVersion) {
|
||||
std::string buffer;
|
||||
auto extractor =
|
||||
CreateMetadataExtractor(kMobileIcaWithLargeMinParseVersion, &buffer);
|
||||
EXPECT_THAT(extractor.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_EQ(
|
||||
extractor.status().message(),
|
||||
absl::StrFormat("Metadata schema version %s is smaller than the minimum "
|
||||
"version 1000.0.0 to parse the metadata flatbuffer.",
|
||||
kMetadataParserVersion));
|
||||
EXPECT_THAT(extractor.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
|
|
|
@ -27,9 +27,9 @@ using ::testing::MatchesRegex;
|
|||
TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) {
|
||||
// Validates that the version is well-formed (x.y.z).
|
||||
#ifdef _WIN32
|
||||
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+"));
|
||||
EXPECT_THAT(kMetadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+"));
|
||||
#else
|
||||
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
|
||||
EXPECT_THAT(kMetadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ using ::tflite::AudioPropertiesBuilder;
|
|||
using ::tflite::BertTokenizerOptionsBuilder;
|
||||
using ::tflite::ContentBuilder;
|
||||
using ::tflite::ContentProperties_AudioProperties;
|
||||
using ::tflite::CustomMetadata;
|
||||
using ::tflite::CustomMetadataBuilder;
|
||||
using ::tflite::ModelMetadataBuilder;
|
||||
using ::tflite::NormalizationOptionsBuilder;
|
||||
using ::tflite::ProcessUnit;
|
||||
|
@ -483,6 +485,34 @@ TEST(MetadataVersionTest,
|
|||
EXPECT_THAT(min_version, StrEq("1.4.1"));
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
|
||||
// Creates a metadata flatbuffer with the field custom_metadata in subgraph
|
||||
// metadata.
|
||||
FlatBufferBuilder builder(1024);
|
||||
auto name = builder.CreateString("custom_metadata");
|
||||
auto data = builder.CreateVector(std::vector<unsigned char>{'a'});
|
||||
CustomMetadataBuilder custom_metadata_builder(builder);
|
||||
custom_metadata_builder.add_name(name);
|
||||
custom_metadata_builder.add_data(data);
|
||||
auto custom_metadata = builder.CreateVector(
|
||||
std::vector<Offset<CustomMetadata>>{custom_metadata_builder.Finish()});
|
||||
SubGraphMetadataBuilder subgraph_builder(builder);
|
||||
subgraph_builder.add_custom_metadata(custom_metadata);
|
||||
auto subgraphs = builder.CreateVector(
|
||||
std::vector<Offset<SubGraphMetadata>>{subgraph_builder.Finish()});
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||
|
||||
// Gets the mimimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is exactly 1.5.0.
|
||||
EXPECT_EQ(min_version, "1.5.0");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
TEST(MetadataVersionTest, CompareVersions) {
|
||||
ASSERT_EQ(0, CompareVersions("1.0", "1.0"));
|
||||
ASSERT_EQ(0, CompareVersions("1", "1.0.0"));
|
||||
|
||||
ASSERT_EQ(-1, CompareVersions("2", "3"));
|
||||
ASSERT_EQ(-1, CompareVersions("1.2", "1.3"));
|
||||
ASSERT_EQ(-1, CompareVersions("3.2.9", "3.2.10"));
|
||||
ASSERT_EQ(-1, CompareVersions("10.1.9", "10.2"));
|
||||
|
||||
ASSERT_EQ(1, CompareVersions("3", "2"));
|
||||
ASSERT_EQ(1, CompareVersions("1.3", "1.2"));
|
||||
ASSERT_EQ(1, CompareVersions("0.95", "0.94.3124"));
|
||||
ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.1.9"));
|
||||
ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.0.13"));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -49,8 +49,7 @@ namespace tflite;
|
|||
// New fields and types will have associated comments with the schema version
|
||||
// for which they were added.
|
||||
//
|
||||
// TODO: Add LINT change check as needed.
|
||||
// Schema Semantic version: 1.4.1
|
||||
// Schema Semantic version: 1.5.0
|
||||
|
||||
// This indicates the flatbuffer compatibility. The number will bump up when a
|
||||
// break change is applied to the schema, such as removing fields or adding new
|
||||
|
@ -69,11 +68,11 @@ file_identifier "M001";
|
|||
// 1.3.0 - Added AudioProperties to ContentProperties.
|
||||
// 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType.
|
||||
// 1.4.1 - Added version to AssociatedFile.
|
||||
// 1.5.0 - Added CustomMetadata in SubGraphMetadata.
|
||||
|
||||
// File extension of any written files.
|
||||
file_extension "tflitemeta";
|
||||
|
||||
// TODO: Add LINT change check as needed.
|
||||
enum AssociatedFileType : byte {
|
||||
UNKNOWN = 0,
|
||||
|
||||
|
@ -609,6 +608,11 @@ table TensorMetadata {
|
|||
associated_files:[AssociatedFile];
|
||||
}
|
||||
|
||||
table CustomMetadata {
|
||||
name:string;
|
||||
data:[ubyte] (force_align: 16);
|
||||
}
|
||||
|
||||
table SubGraphMetadata {
|
||||
// Name of the subgraph.
|
||||
//
|
||||
|
@ -676,6 +680,7 @@ table SubGraphMetadata {
|
|||
// Added in: 1.2.0
|
||||
output_tensor_groups:[TensorGroup];
|
||||
|
||||
custom_metadata:[CustomMetadata];
|
||||
}
|
||||
|
||||
table ModelMetadata {
|
||||
|
@ -721,5 +726,6 @@ table ModelMetadata {
|
|||
// the metadata is populated into a TFLite model.
|
||||
min_parser_version:string;
|
||||
}
|
||||
// metadata_version.cc)
|
||||
|
||||
root_type ModelMetadata;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ==============================================================================
|
||||
"""Helper classes for common model metadata information."""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import csv
|
||||
import os
|
||||
|
@ -377,11 +378,12 @@ class TensorMd:
|
|||
tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
|
||||
used to locate the corresponding tensor and decide the order of the tensor
|
||||
metadata [2] when populating model metadata.
|
||||
content_range_md: information of content range [3]. [1]:
|
||||
content_range_md: information of content range [3].
|
||||
[1]:
|
||||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||
[2]:
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||
[3]:
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
||||
"""
|
||||
|
||||
|
@ -777,17 +779,18 @@ class ClassificationTensorMd(TensorMd):
|
|||
order of the tensor metadata [4] when populating model metadata.
|
||||
score_thresholding_md: information of the score thresholding [5] in the
|
||||
classification tensor.
|
||||
content_range_md: information of content range [6]. [1]:
|
||||
content_range_md: information of content range [6].
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[2]:
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
[3]:
|
||||
[3]:
|
||||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||
[4]:
|
||||
[4]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||
[5]:
|
||||
[5]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
[6]:
|
||||
[6]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
||||
"""
|
||||
self.score_calibration_md = score_calibration_md
|
||||
|
@ -890,9 +893,10 @@ class CategoryTensorMd(TensorMd):
|
|||
name: name of the tensor.
|
||||
description: description of what the tensor is.
|
||||
label_files: information of the label files [1] in the category tensor.
|
||||
content_range_md: information of content range [2]. [1]:
|
||||
content_range_md: information of content range [2].
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116
|
||||
[2]:
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
||||
"""
|
||||
# In category tensors, label files are in the type of TENSOR_VALUE_LABELS.
|
||||
|
@ -934,9 +938,10 @@ class DetectionOutputTensorsMd:
|
|||
label_files: information of the label files [1] in the classification
|
||||
tensor.
|
||||
score_calibration_md: information of the score calibration files operation
|
||||
[2] in the classification tensor. [1]:
|
||||
[2] in the classification tensor.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[2]:
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
"""
|
||||
content_range_md = ValueRangeMd(
|
||||
|
@ -1010,7 +1015,8 @@ class TensorGroupMd:
|
|||
Args:
|
||||
name: name of tensor group.
|
||||
tensor_names: Names of the tensors to group together, corresponding to
|
||||
TensorMetadata.name [1]. [1]:
|
||||
TensorMetadata.name [1].
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564
|
||||
"""
|
||||
self.name = name
|
||||
|
@ -1022,3 +1028,14 @@ class TensorGroupMd:
|
|||
group.name = self.name
|
||||
group.tensorNames = self.tensor_names
|
||||
return group
|
||||
|
||||
|
||||
class CustomMetadataMd(abc.ABC):
|
||||
"""An abstract class of a container for the custom metadata information."""
|
||||
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
self.name = name
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
|
||||
"""Creates the custom metadata based on the information."""
|
||||
|
|
|
@ -317,6 +317,7 @@ def _create_metadata_buffer(
|
|||
output_md: Optional[List[metadata_info.TensorMd]] = None,
|
||||
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None,
|
||||
output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None,
|
||||
custom_metadata_md: Optional[List[metadata_info.CustomMetadataMd]] = None,
|
||||
) -> bytearray:
|
||||
"""Creates a buffer of the metadata.
|
||||
|
||||
|
@ -326,9 +327,11 @@ def _create_metadata_buffer(
|
|||
input_md: metadata information of the input tensors.
|
||||
output_md: metadata information of the output tensors.
|
||||
input_process_units: a lists of metadata of the input process units [1].
|
||||
output_group_md: a list of metadata of output tensor groups [2]; [1]:
|
||||
output_group_md: a list of metadata of output tensor groups [2];
|
||||
custom_metadata_md: a lists of custom metadata.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
|
||||
[2]:
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677
|
||||
|
||||
Returns:
|
||||
|
@ -367,6 +370,10 @@ def _create_metadata_buffer(
|
|||
subgraph_metadata.outputTensorMetadata = output_metadata
|
||||
if input_process_units:
|
||||
subgraph_metadata.inputProcessUnits = input_process_units
|
||||
if custom_metadata_md:
|
||||
subgraph_metadata.customMetadata = [
|
||||
m.create_metadata() for m in custom_metadata_md
|
||||
]
|
||||
if output_group_md:
|
||||
subgraph_metadata.outputTensorGroups = [
|
||||
m.create_metadata() for m in output_group_md
|
||||
|
@ -416,6 +423,7 @@ class MetadataWriter(object):
|
|||
self._output_mds = []
|
||||
self._output_group_mds = []
|
||||
self._associated_files = []
|
||||
self._custom_metadata_mds = []
|
||||
self._temp_folder = tempfile.TemporaryDirectory()
|
||||
|
||||
def __del__(self):
|
||||
|
@ -657,6 +665,12 @@ class MetadataWriter(object):
|
|||
self._output_mds.append(output_md)
|
||||
return self
|
||||
|
||||
def add_custom_metadata(
|
||||
self, custom_metadata_md: metadata_info.CustomMetadataMd
|
||||
) -> 'MetadataWriter':
|
||||
self._custom_metadata_mds.append(custom_metadata_md)
|
||||
return self
|
||||
|
||||
def populate(self) -> Tuple[bytearray, str]:
|
||||
"""Populates metadata into the TFLite file.
|
||||
|
||||
|
@ -674,6 +688,7 @@ class MetadataWriter(object):
|
|||
input_md=self._input_mds,
|
||||
output_md=self._output_mds,
|
||||
input_process_units=self._input_process_units,
|
||||
custom_metadata_md=self._custom_metadata_mds,
|
||||
output_group_md=self._output_group_mds,
|
||||
)
|
||||
populator.load_metadata_buffer(metadata_buffer)
|
||||
|
|
|
@ -33,6 +33,8 @@ py_test(
|
|||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
|
|
|
@ -18,6 +18,8 @@ import tempfile
|
|||
|
||||
from absl.testing import absltest
|
||||
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
@ -29,6 +31,16 @@ _SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
|||
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
|
||||
|
||||
|
||||
class TestCustomMetadataMd(metadata_info.CustomMetadataMd):
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
|
||||
"""Creates the custom metadata based on the information."""
|
||||
custom_metadata_field = _metadata_fb.CustomMetadataT()
|
||||
custom_metadata_field.name = self.name
|
||||
custom_metadata_field.data = b'\x01\x02'
|
||||
return custom_metadata_field
|
||||
|
||||
|
||||
class LabelsTest(absltest.TestCase):
|
||||
|
||||
def test_category_name(self):
|
||||
|
@ -415,7 +427,6 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
score_thresholding=metadata_writer.ScoreThresholding(
|
||||
global_score_threshold=0.5))
|
||||
_, metadata_json = writer.populate()
|
||||
print(metadata_json)
|
||||
self.assertJsonEqual(
|
||||
metadata_json, """{
|
||||
"subgraph_metadata": [
|
||||
|
@ -465,6 +476,46 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
|||
}
|
||||
""")
|
||||
|
||||
def test_add_custom_metadata(self):
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
self.image_classifier_model_buffer
|
||||
)
|
||||
writer.add_custom_metadata(
|
||||
TestCustomMetadataMd(name='test_custom_metadata')
|
||||
)
|
||||
_, metadata_json = writer.populate()
|
||||
self.assertJsonEqual(
|
||||
metadata_json,
|
||||
"""
|
||||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "input"
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "MobilenetV1/Predictions/Reshape_1"
|
||||
}
|
||||
],
|
||||
"custom_metadata": [
|
||||
{
|
||||
"name": "test_custom_metadata",
|
||||
"data": [
|
||||
1,
|
||||
2
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.5.0"
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
4
mediapipe/tasks/testdata/metadata/BUILD
vendored
4
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -31,6 +31,8 @@ mediapipe_files(srcs = [
|
|||
"efficientdet_lite0_v1.json",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"labelmap.txt",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||
"mobile_ica_8bit-with-metadata.tflite",
|
||||
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
|
||||
"mobile_ica_8bit-without-model-metadata.tflite",
|
||||
|
@ -86,6 +88,8 @@ filegroup(
|
|||
"bert_text_classifier_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||
"mobile_ica_8bit-with-metadata.tflite",
|
||||
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
|
||||
"mobile_ica_8bit-without-model-metadata.tflite",
|
||||
|
|
12
third_party/external_files.bzl
vendored
12
third_party/external_files.bzl
vendored
|
@ -544,6 +544,18 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobile_ica_8bit-with-custom-metadata_tflite",
|
||||
sha256 = "31f34f0dd0dc39e69e9c3deb1e3f3278febeb82ecf57c235834348a75df8fb51",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-custom-metadata.tflite?generation=1677906531317767"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobile_ica_8bit-with-large-min-parser-version_tflite",
|
||||
sha256 = "53d0ea047682539964820fcfc5dc81f4928957470f453f2065f4c2ab87406803",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-large-min-parser-version.tflite?generation=1677906534624784"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite",
|
||||
sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5",
|
||||
|
|
Loading…
Reference in New Issue
Block a user