diff --git a/mediapipe/tasks/cc/metadata/BUILD b/mediapipe/tasks/cc/metadata/BUILD index ef32dd184..186940bce 100644 --- a/mediapipe/tasks/cc/metadata/BUILD +++ b/mediapipe/tasks/cc/metadata/BUILD @@ -14,9 +14,13 @@ stamp_metadata_parser_version( cc_library( name = "metadata_extractor", srcs = ["metadata_extractor.cc"], - hdrs = ["metadata_extractor.h"], + hdrs = [ + "metadata_extractor.h", + "metadata_parser_h", + ], visibility = ["//visibility:public"], deps = [ + ":metadata_version_utils", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/metadata/utils:zip_utils", @@ -68,3 +72,10 @@ cc_library( "@zlib//:zlib_minizip", ], ) + +cc_library( + name = "metadata_version_utils", + srcs = ["metadata_version_utils.cc"], + hdrs = ["metadata_version_utils.h"], + deps = ["@com_google_absl//absl/strings"], +) diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.cc b/mediapipe/tasks/cc/metadata/metadata_extractor.cc index 4bc3e8ba0..4d6f526f5 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.cc +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.cc @@ -26,6 +26,8 @@ limitations under the License. 
#include "flatbuffers/flatbuffers.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/metadata/metadata_parser.h" +#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -164,6 +166,18 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( return CreateStatusWithPayload(StatusCode::kInternal, "Expected Model Metadata not to be null."); } + auto min_parser_version = model_metadata_->min_parser_version(); + if (min_parser_version != nullptr && + CompareVersions(min_parser_version->c_str(), kMetadataParserVersion) > + 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Metadata schema version %s is smaller than the minimum version " + "%s to parse the metadata flatbuffer.", + kMetadataParserVersion, min_parser_version->c_str()), + MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError); + } return ExtractAssociatedFiles(buffer_data, buffer_size); break; } @@ -299,6 +313,29 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const { return output_process_units == nullptr ? 0 : output_process_units->size(); } +const flatbuffers::Vector>* +ModelMetadataExtractor::GetCustomMetadataList() const { + if (model_metadata_ == nullptr || + model_metadata_->subgraph_metadata() == nullptr) { + return nullptr; + } + return model_metadata_->subgraph_metadata() + ->Get(kDefaultSubgraphIndex) + ->custom_metadata(); +} + +const tflite::CustomMetadata* ModelMetadataExtractor::GetCustomMetadata( + int index) const { + return GetItemFromVector(GetCustomMetadataList(), + index); +} + +int ModelMetadataExtractor::GetCustomMetadataCount() const { + const Vector>* custom_medata_vec = + GetCustomMetadataList(); + return custom_medata_vec == nullptr ? 
0 : custom_medata_vec->size(); +} + } // namespace metadata } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.h b/mediapipe/tasks/cc/metadata/metadata_extractor.h index e74ac50a3..b88eda863 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.h +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.h @@ -136,6 +136,19 @@ class ModelMetadataExtractor { // there is no output process units. int GetOutputProcessUnitsCount() const; + // Gets a list of custom metadata from SubgraphMetadata.custom_metadata, + // could be nullptr. + const flatbuffers::Vector>* + GetCustomMetadataList() const; + + // Gets the custom metadata specified by the given index, or nullptr in case + // there is no custom metadata or the index is out of range. + const tflite::CustomMetadata* GetCustomMetadata(int index) const; + + // Gets the count of custom metadata. In particular, 0 is returned when + // there is no custom metadata. + int GetCustomMetadataCount() const; + private: static constexpr int kDefaultSubgraphIndex = 0; // Private default constructor, called from CreateFromModel(). diff --git a/mediapipe/tasks/cc/metadata/metadata_parser.h.template b/mediapipe/tasks/cc/metadata/metadata_parser.h.template index f5ebfa04d..28c38af82 100644 --- a/mediapipe/tasks/cc/metadata/metadata_parser.h.template +++ b/mediapipe/tasks/cc/metadata/metadata_parser.h.template @@ -21,7 +21,7 @@ namespace metadata { // The version of the metadata parser that this metadata versioning library is // depending on. 
-inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}"; +inline constexpr char kMetadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}"; } // namespace metadata } // namespace tasks diff --git a/mediapipe/tasks/cc/metadata/metadata_version.cc b/mediapipe/tasks/cc/metadata/metadata_version.cc index 7b9f123cb..e836eb7ec 100644 --- a/mediapipe/tasks/cc/metadata/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/metadata_version.cc @@ -57,6 +57,7 @@ enum class SchemaMembers { kContentPropertiesAudioProperties = 8, kAssociatedFileTypeScannIndexFile = 9, kAssociatedFileVersion = 10, + kCustomMetadata = 11, }; // Helper class to compare semantic versions in terms of three integers, major, @@ -125,6 +126,8 @@ Version GetMemberVersion(SchemaMembers member) { return Version(1, 4, 0); case SchemaMembers::kAssociatedFileVersion: return Version(1, 4, 1); + case SchemaMembers::kCustomMetadata: + return Version(1, 5, 0); default: // Should never happen. TFLITE_LOG(FATAL) << "Unsupported schema member: " @@ -281,6 +284,12 @@ void UpdateMinimumVersionForTable( GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups), min_version); } + + // Checks for the custom_metadata field. 
+  if (table->custom_metadata() != nullptr) {
+    UpdateMinimumVersion(GetMemberVersion(SchemaMembers::kCustomMetadata),
+                         min_version);
+  }
 }
 
 template <>
diff --git a/mediapipe/tasks/cc/metadata/metadata_version_utils.cc b/mediapipe/tasks/cc/metadata/metadata_version_utils.cc
new file mode 100644
index 000000000..6a7b6e071
--- /dev/null
+++ b/mediapipe/tasks/cc/metadata/metadata_version_utils.cc
@@ -0,0 +1,63 @@
+#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_split.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace metadata {
+namespace {
+
+// Returns the integer value of the version component at `index`, or 0 when
+// `index` is outside `list`. Treating missing components as 0 is what makes
+// "1" compare equal to "1.0.0".
+//
+// NOTE(review): std::stoi throws on a component that does not start with a
+// number; confirm whether malformed version strings can reach this helper.
+static int32_t GetValueOrZero(const std::vector<std::string> &list,
+                              const int index) {
+  // Test `index < size` instead of `index <= size - 1`: `size()` is unsigned,
+  // so for an empty list `size() - 1` wraps around to the maximum value and
+  // every index would pass the check, reading out of bounds.
+  if (index < 0 || static_cast<size_t>(index) >= list.size()) {
+    return 0;
+  }
+  return std::stoi(list[index]);
+}
+
+}  // namespace
+
+// Compares two dot-separated versions component by component, padding the
+// shorter one with zeros. Returns 1 if `version_a` is newer, -1 if it is older
+// and 0 if both are equal. Empty components are skipped when splitting.
+int CompareVersions(absl::string_view version_a, absl::string_view version_b) {
+  std::vector<std::string> version_a_components =
+      absl::StrSplit(version_a, '.', absl::SkipEmpty());
+  std::vector<std::string> version_b_components =
+      absl::StrSplit(version_b, '.', absl::SkipEmpty());
+
+  const int a_length = version_a_components.size();
+  const int b_length = version_b_components.size();
+  const int max_length = std::max(a_length, b_length);
+
+  for (int i = 0; i < max_length; ++i) {
+    const int a_val = GetValueOrZero(version_a_components, i);
+    const int b_val = GetValueOrZero(version_b_components, i);
+    if (a_val > b_val) {
+      return 1;
+    }
+    if (a_val < b_val) {
+      return -1;
+    }
+  }
+  return 0;
+}
+
+}  // namespace metadata
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/metadata/metadata_version_utils.h b/mediapipe/tasks/cc/metadata/metadata_version_utils.h
new file mode 100644
index 000000000..3a2bfb666
--- /dev/null
+++ b/mediapipe/tasks/cc/metadata/metadata_version_utils.h
@@ -0,0 +1,21 @@
+#ifndef MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
+#define MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
+
+#include 
"absl/strings/string_view.h" + +namespace mediapipe { +namespace tasks { +namespace metadata { + +// Compares two versions. The version format is "**.**.**" such as "1.12.3". +// If version_a is newer than version_b, return 1; if version_a is +// older than version_b, return -1; if version_a equals to version_b, +// returns 0. For example, if version_a = 1.12.3 and version_b = 1.12.1, +// version_a is newer than version_b, and the function return is 1. +int CompareVersions(absl::string_view version_a, absl::string_view version_b); + +} // namespace metadata +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_ diff --git a/mediapipe/tasks/cc/metadata/tests/BUILD b/mediapipe/tasks/cc/metadata/tests/BUILD index 33cbf6b54..7732b3460 100644 --- a/mediapipe/tasks/cc/metadata/tests/BUILD +++ b/mediapipe/tasks/cc/metadata/tests/BUILD @@ -22,6 +22,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", ], ) @@ -44,3 +45,12 @@ cc_test( "//mediapipe/tasks/cc/metadata:metadata_version", ], ) + +cc_test( + name = "metadata_version_utils_test", + srcs = ["metadata_version_utils_test.cc"], + deps = [ + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/metadata:metadata_version_utils", + ], +) diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc index 41f664158..0e05e5167 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -26,6 +27,7 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_parser.h" namespace mediapipe { namespace tasks { @@ -43,6 +45,11 @@ constexpr char kMobileIcaWithoutTfLiteMetadata[] = "mobile_ica_8bit-without-model-metadata.tflite"; constexpr char kMobileIcaWithTfLiteMetadata[] = "mobile_ica_8bit-with-metadata.tflite"; +constexpr char kMobileIcaWithCustomMetadata[] = + "mobile_ica_8bit-with-custom-metadata.tflite"; +// `min_parser_version=1000.0.0` for test purpose. +constexpr char kMobileIcaWithLargeMinParseVersion[] = + "mobile_ica_8bit-with-large-min-parser-version.tflite"; constexpr char kMobileIcaWithUnsupportedMetadataVersion[] = "mobile_ica_8bit-with-unsupported-metadata-version.tflite"; @@ -334,6 +341,81 @@ TEST(ModelMetadataExtractorTest, GetModelVersionWorks) { MP_EXPECT_OK(extractor->GetModelVersion().status()); } +TEST(ModelMetadataExtractorTest, GetCustomMetadataListWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer)); + EXPECT_TRUE(extractor->GetCustomMetadataList() != nullptr); +} + +TEST(ModelMetadataExtractorTest, + GetCustomMetadataListWithoutTfLiteMetadataWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); + EXPECT_TRUE(extractor->GetCustomMetadataList() == nullptr); +} + +TEST(ModelMetadataExtractorTest, GetCustomMetadataWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + 
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer)); + EXPECT_TRUE(extractor->GetCustomMetadata(0) != nullptr); +} + +TEST(ModelMetadataExtractorTest, GetCustomMetadataWithoutTfLiteMetadataWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); + EXPECT_TRUE(extractor->GetCustomMetadata(0) == nullptr); +} + +TEST(ModelMetadataExtractorTest, GetCustomMetadataWithOutOfRangeIndexWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); + EXPECT_TRUE(extractor->GetCustomMetadata(-1) == nullptr); + EXPECT_TRUE(extractor->GetCustomMetadata(2) == nullptr); +} + +TEST(ModelMetadataExtractorTest, GetCustomMetadataCountWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer)); + EXPECT_EQ(extractor->GetCustomMetadataCount(), 2); +} + +TEST(ModelMetadataExtractorTest, + GetCustomMetadataCountWithoutTfLiteMetadataWorks) { + std::string buffer; + MP_ASSERT_OK_AND_ASSIGN( + std::unique_ptr extractor, + CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); + EXPECT_EQ(extractor->GetCustomMetadataCount(), 0); +} + +TEST(ModelMetadataExtractorTest, + CreateFailsWithIncompatibleWithMinParserVersion) { + std::string buffer; + auto extractor = + CreateMetadataExtractor(kMobileIcaWithLargeMinParseVersion, &buffer); + EXPECT_THAT(extractor.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ( + extractor.status().message(), + absl::StrFormat("Metadata schema version %s is smaller than the minimum " + "version 1000.0.0 to parse the metadata flatbuffer.", + kMetadataParserVersion)); + EXPECT_THAT(extractor.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError)))); +} + } // 
namespace } // namespace metadata } // namespace tasks diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc index 1d2e22cc7..0d613e65e 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc @@ -27,9 +27,9 @@ using ::testing::MatchesRegex; TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) { // Validates that the version is well-formed (x.y.z). #ifdef _WIN32 - EXPECT_THAT(kMatadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+")); + EXPECT_THAT(kMetadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+")); #else - EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+")); + EXPECT_THAT(kMetadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+")); #endif // _WIN32 } diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc index 967853028..273c91685 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc @@ -37,6 +37,8 @@ using ::tflite::AudioPropertiesBuilder; using ::tflite::BertTokenizerOptionsBuilder; using ::tflite::ContentBuilder; using ::tflite::ContentProperties_AudioProperties; +using ::tflite::CustomMetadata; +using ::tflite::CustomMetadataBuilder; using ::tflite::ModelMetadataBuilder; using ::tflite::NormalizationOptionsBuilder; using ::tflite::ProcessUnit; @@ -483,6 +485,34 @@ TEST(MetadataVersionTest, EXPECT_THAT(min_version, StrEq("1.4.1")); } +TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) { + // Creates a metadata flatbuffer with the field custom_metadata in subgraph + // metadata. 
+ FlatBufferBuilder builder(1024); + auto name = builder.CreateString("custom_metadata"); + auto data = builder.CreateVector(std::vector{'a'}); + CustomMetadataBuilder custom_metadata_builder(builder); + custom_metadata_builder.add_name(name); + custom_metadata_builder.add_data(data); + auto custom_metadata = builder.CreateVector( + std::vector>{custom_metadata_builder.Finish()}); + SubGraphMetadataBuilder subgraph_builder(builder); + subgraph_builder.add_custom_metadata(custom_metadata); + auto subgraphs = builder.CreateVector( + std::vector>{subgraph_builder.Finish()}); + ModelMetadataBuilder metadata_builder(builder); + metadata_builder.add_subgraph_metadata(subgraphs); + FinishModelMetadataBuffer(builder, metadata_builder.Finish()); + + // Gets the mimimum metadata parser version. + std::string min_version; + EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), + builder.GetSize(), &min_version), + kTfLiteOk); + // Validates that the version is exactly 1.5.0. + EXPECT_EQ(min_version, "1.5.0"); +} + } // namespace } // namespace metadata } // namespace tasks diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc new file mode 100644 index 000000000..eaaa39f0e --- /dev/null +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { +namespace tasks { +namespace metadata { +namespace { + +TEST(MetadataVersionTest, CompareVersions) { + ASSERT_EQ(0, CompareVersions("1.0", "1.0")); + ASSERT_EQ(0, CompareVersions("1", "1.0.0")); + + ASSERT_EQ(-1, CompareVersions("2", "3")); + ASSERT_EQ(-1, CompareVersions("1.2", "1.3")); + ASSERT_EQ(-1, CompareVersions("3.2.9", "3.2.10")); + ASSERT_EQ(-1, CompareVersions("10.1.9", "10.2")); + + ASSERT_EQ(1, CompareVersions("3", "2")); + ASSERT_EQ(1, CompareVersions("1.3", "1.2")); + ASSERT_EQ(1, CompareVersions("0.95", "0.94.3124")); + ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.1.9")); + ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.0.13")); +} +} // namespace +} // namespace metadata +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/metadata/metadata_schema.fbs b/mediapipe/tasks/metadata/metadata_schema.fbs index 933fdfb2a..56161b711 100644 --- a/mediapipe/tasks/metadata/metadata_schema.fbs +++ b/mediapipe/tasks/metadata/metadata_schema.fbs @@ -49,8 +49,7 @@ namespace tflite; // New fields and types will have associated comments with the schema version // for which they were added. // -// TODO: Add LINT change check as needed. -// Schema Semantic version: 1.4.1 +// Schema Semantic version: 1.5.0 // This indicates the flatbuffer compatibility. The number will bump up when a // break change is applied to the schema, such as removing fields or adding new @@ -69,11 +68,11 @@ file_identifier "M001"; // 1.3.0 - Added AudioProperties to ContentProperties. // 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType. // 1.4.1 - Added version to AssociatedFile. +// 1.5.0 - Added CustomMetadata in SubGraphMetadata. // File extension of any written files. 
file_extension "tflitemeta"; -// TODO: Add LINT change check as needed. enum AssociatedFileType : byte { UNKNOWN = 0, @@ -609,6 +608,11 @@ table TensorMetadata { associated_files:[AssociatedFile]; } +table CustomMetadata { + name:string; + data:[ubyte] (force_align: 16); +} + table SubGraphMetadata { // Name of the subgraph. // @@ -676,6 +680,7 @@ table SubGraphMetadata { // Added in: 1.2.0 output_tensor_groups:[TensorGroup]; + custom_metadata:[CustomMetadata]; } table ModelMetadata { @@ -721,5 +726,6 @@ table ModelMetadata { // the metadata is populated into a TFLite model. min_parser_version:string; } +// metadata_version.cc) root_type ModelMetadata; diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py index 6428b835f..4794a12fc 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py @@ -14,6 +14,7 @@ # ============================================================================== """Helper classes for common model metadata information.""" +import abc import collections import csv import os @@ -377,11 +378,12 @@ class TensorMd: tensor_name: name of the corresponding tensor [1] in the TFLite model. It is used to locate the corresponding tensor and decide the order of the tensor metadata [2] when populating model metadata. - content_range_md: information of content range [3]. [1]: + content_range_md: information of content range [3]. 
+ [1]: https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 - [2]: + [2]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 - [3]: + [3]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385 """ @@ -777,17 +779,18 @@ class ClassificationTensorMd(TensorMd): order of the tensor metadata [4] when populating model metadata. score_thresholding_md: information of the score thresholding [5] in the classification tensor. - content_range_md: information of content range [6]. [1]: + content_range_md: information of content range [6]. + [1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 - [2]: + [2]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 - [3]: + [3]: https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 - [4]: + [4]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 - [5]: + [5]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468 - [6]: + [6]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385 """ self.score_calibration_md = score_calibration_md @@ -890,9 +893,10 @@ class CategoryTensorMd(TensorMd): name: name of the tensor. description: description of what the tensor is. label_files: information of the label files [1] in the category tensor. - content_range_md: information of content range [2]. 
[1]: + content_range_md: information of content range [2]. + [1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116 - [2]: + [2]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385 """ # In category tensors, label files are in the type of TENSOR_VALUE_LABELS. @@ -934,9 +938,10 @@ class DetectionOutputTensorsMd: label_files: information of the label files [1] in the classification tensor. score_calibration_md: information of the score calibration files operation - [2] in the classification tensor. [1]: + [2] in the classification tensor. + [1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 - [2]: + [2]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 """ content_range_md = ValueRangeMd( @@ -1010,7 +1015,8 @@ class TensorGroupMd: Args: name: name of tensor group. tensor_names: Names of the tensors to group together, corresponding to - TensorMetadata.name [1]. [1]: + TensorMetadata.name [1]. 
+ [1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564 """ self.name = name @@ -1022,3 +1028,14 @@ class TensorGroupMd: group.name = self.name group.tensorNames = self.tensor_names return group + + +class CustomMetadataMd(abc.ABC): + """An abstract class of a container for the custom metadata information.""" + + def __init__(self, name: Optional[str] = None): + self.name = name + + @abc.abstractmethod + def create_metadata(self) -> _metadata_fb.CustomMetadataT: + """Creates the custom metadata based on the information.""" diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py index 240655a88..fda6a64d3 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py @@ -317,6 +317,7 @@ def _create_metadata_buffer( output_md: Optional[List[metadata_info.TensorMd]] = None, input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None, output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None, + custom_metadata_md: Optional[List[metadata_info.CustomMetadataMd]] = None, ) -> bytearray: """Creates a buffer of the metadata. @@ -326,9 +327,11 @@ def _create_metadata_buffer( input_md: metadata information of the input tensors. output_md: metadata information of the output tensors. input_process_units: a lists of metadata of the input process units [1]. - output_group_md: a list of metadata of output tensor groups [2]; [1]: + output_group_md: a list of metadata of output tensor groups [2]; + custom_metadata_md: a lists of custom metadata. 
+ [1]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655 - [2]: + [2]: https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677 Returns: @@ -367,6 +370,10 @@ def _create_metadata_buffer( subgraph_metadata.outputTensorMetadata = output_metadata if input_process_units: subgraph_metadata.inputProcessUnits = input_process_units + if custom_metadata_md: + subgraph_metadata.customMetadata = [ + m.create_metadata() for m in custom_metadata_md + ] if output_group_md: subgraph_metadata.outputTensorGroups = [ m.create_metadata() for m in output_group_md @@ -416,6 +423,7 @@ class MetadataWriter(object): self._output_mds = [] self._output_group_mds = [] self._associated_files = [] + self._custom_metadata_mds = [] self._temp_folder = tempfile.TemporaryDirectory() def __del__(self): @@ -657,6 +665,12 @@ class MetadataWriter(object): self._output_mds.append(output_md) return self + def add_custom_metadata( + self, custom_metadata_md: metadata_info.CustomMetadataMd + ) -> 'MetadataWriter': + self._custom_metadata_mds.append(custom_metadata_md) + return self + def populate(self) -> Tuple[bytearray, str]: """Populates metadata into the TFLite file. 
@@ -674,6 +688,7 @@ class MetadataWriter(object): input_md=self._input_mds, output_md=self._output_mds, input_process_units=self._input_process_units, + custom_metadata_md=self._custom_metadata_mds, output_group_md=self._output_group_mds, ) populator.load_metadata_buffer(metadata_buffer) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD index 7088e341a..417e3e10c 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD @@ -33,6 +33,8 @@ py_test( "//mediapipe/tasks/testdata/metadata:model_files", ], deps = [ + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py index 3c2eb407c..846274914 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py @@ -18,6 +18,8 @@ import tempfile from absl.testing import absltest +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.python.metadata.metadata_writers import metadata_info from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.test import test_utils @@ -29,6 +31,16 @@ _SCORE_CALIBRATION_FILE = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, 'score_calibration.txt')) +class TestCustomMetadataMd(metadata_info.CustomMetadataMd): + + def create_metadata(self) -> _metadata_fb.CustomMetadataT: + """Creates the custom metadata based on the information.""" + custom_metadata_field = 
_metadata_fb.CustomMetadataT() + custom_metadata_field.name = self.name + custom_metadata_field.data = b'\x01\x02' + return custom_metadata_field + + class LabelsTest(absltest.TestCase): def test_category_name(self): @@ -415,7 +427,6 @@ class MetadataWriterForTaskTest(absltest.TestCase): score_thresholding=metadata_writer.ScoreThresholding( global_score_threshold=0.5)) _, metadata_json = writer.populate() - print(metadata_json) self.assertJsonEqual( metadata_json, """{ "subgraph_metadata": [ @@ -465,6 +476,46 @@ class MetadataWriterForTaskTest(absltest.TestCase): } """) + def test_add_custom_metadata(self): + writer = metadata_writer.MetadataWriter.create( + self.image_classifier_model_buffer + ) + writer.add_custom_metadata( + TestCustomMetadataMd(name='test_custom_metadata') + ) + _, metadata_json = writer.populate() + self.assertJsonEqual( + metadata_json, + """ + { + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input" + } + ], + "output_tensor_metadata": [ + { + "name": "MobilenetV1/Predictions/Reshape_1" + } + ], + "custom_metadata": [ + { + "name": "test_custom_metadata", + "data": [ + 1, + 2 + ] + } + ] + } + ], + "min_parser_version": "1.5.0" + } + """, + ) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 7b2812260..0ac06caac 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -31,6 +31,8 @@ mediapipe_files(srcs = [ "efficientdet_lite0_v1.json", "efficientdet_lite0_v1.tflite", "labelmap.txt", + "mobile_ica_8bit-with-custom-metadata.tflite", + "mobile_ica_8bit-with-large-min-parser-version.tflite", "mobile_ica_8bit-with-metadata.tflite", "mobile_ica_8bit-with-unsupported-metadata-version.tflite", "mobile_ica_8bit-without-model-metadata.tflite", @@ -86,6 +88,8 @@ filegroup( "bert_text_classifier_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", 
"efficientdet_lite0_v1.tflite", + "mobile_ica_8bit-with-custom-metadata.tflite", + "mobile_ica_8bit-with-large-min-parser-version.tflite", "mobile_ica_8bit-with-metadata.tflite", "mobile_ica_8bit-with-unsupported-metadata-version.tflite", "mobile_ica_8bit-without-model-metadata.tflite", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index edd33dcb2..1fb53ba51 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -544,6 +544,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"], ) + http_file( + name = "com_google_mediapipe_mobile_ica_8bit-with-custom-metadata_tflite", + sha256 = "31f34f0dd0dc39e69e9c3deb1e3f3278febeb82ecf57c235834348a75df8fb51", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-custom-metadata.tflite?generation=1677906531317767"], + ) + + http_file( + name = "com_google_mediapipe_mobile_ica_8bit-with-large-min-parser-version_tflite", + sha256 = "53d0ea047682539964820fcfc5dc81f4928957470f453f2065f4c2ab87406803", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-large-min-parser-version.tflite?generation=1677906534624784"], + ) + http_file( name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite", sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5",