Internal change
PiperOrigin-RevId: 514001732
This commit is contained in:
		
							parent
							
								
									c98b4b6ec6
								
							
						
					
					
						commit
						dbe4175a08
					
				|  | @ -14,9 +14,13 @@ stamp_metadata_parser_version( | ||||||
| cc_library( | cc_library( | ||||||
|     name = "metadata_extractor", |     name = "metadata_extractor", | ||||||
|     srcs = ["metadata_extractor.cc"], |     srcs = ["metadata_extractor.cc"], | ||||||
|     hdrs = ["metadata_extractor.h"], |     hdrs = [ | ||||||
|  |         "metadata_extractor.h", | ||||||
|  |         "metadata_parser_h", | ||||||
|  |     ], | ||||||
|     visibility = ["//visibility:public"], |     visibility = ["//visibility:public"], | ||||||
|     deps = [ |     deps = [ | ||||||
|  |         ":metadata_version_utils", | ||||||
|         "//mediapipe/framework/port:status", |         "//mediapipe/framework/port:status", | ||||||
|         "//mediapipe/tasks/cc:common", |         "//mediapipe/tasks/cc:common", | ||||||
|         "//mediapipe/tasks/cc/metadata/utils:zip_utils", |         "//mediapipe/tasks/cc/metadata/utils:zip_utils", | ||||||
|  | @ -68,3 +72,10 @@ cc_library( | ||||||
|         "@zlib//:zlib_minizip", |         "@zlib//:zlib_minizip", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "metadata_version_utils", | ||||||
|  |     srcs = ["metadata_version_utils.cc"], | ||||||
|  |     hdrs = ["metadata_version_utils.h"], | ||||||
|  |     deps = ["@com_google_absl//absl/strings"], | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -26,6 +26,8 @@ limitations under the License. | ||||||
| #include "flatbuffers/flatbuffers.h" | #include "flatbuffers/flatbuffers.h" | ||||||
| #include "mediapipe/framework/port/status_macros.h" | #include "mediapipe/framework/port/status_macros.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
|  | #include "mediapipe/tasks/cc/metadata/metadata_parser.h" | ||||||
|  | #include "mediapipe/tasks/cc/metadata/metadata_version_utils.h" | ||||||
| #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" | #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" | ||||||
| #include "mediapipe/tasks/metadata/metadata_schema_generated.h" | #include "mediapipe/tasks/metadata/metadata_schema_generated.h" | ||||||
| #include "tensorflow/lite/schema/schema_generated.h" | #include "tensorflow/lite/schema/schema_generated.h" | ||||||
|  | @ -164,6 +166,18 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( | ||||||
|       return CreateStatusWithPayload(StatusCode::kInternal, |       return CreateStatusWithPayload(StatusCode::kInternal, | ||||||
|                                      "Expected Model Metadata not to be null."); |                                      "Expected Model Metadata not to be null."); | ||||||
|     } |     } | ||||||
|  |     auto min_parser_version = model_metadata_->min_parser_version(); | ||||||
|  |     if (min_parser_version != nullptr && | ||||||
|  |         CompareVersions(min_parser_version->c_str(), kMetadataParserVersion) > | ||||||
|  |             0) { | ||||||
|  |       return CreateStatusWithPayload( | ||||||
|  |           StatusCode::kInvalidArgument, | ||||||
|  |           absl::StrFormat( | ||||||
|  |               "Metadata schema version %s is smaller than the minimum version " | ||||||
|  |               "%s to parse the metadata flatbuffer.", | ||||||
|  |               kMetadataParserVersion, min_parser_version->c_str()), | ||||||
|  |           MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError); | ||||||
|  |     } | ||||||
|     return ExtractAssociatedFiles(buffer_data, buffer_size); |     return ExtractAssociatedFiles(buffer_data, buffer_size); | ||||||
|     break; |     break; | ||||||
|   } |   } | ||||||
|  | @ -299,6 +313,29 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const { | ||||||
|   return output_process_units == nullptr ? 0 : output_process_units->size(); |   return output_process_units == nullptr ? 0 : output_process_units->size(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>* | ||||||
|  | ModelMetadataExtractor::GetCustomMetadataList() const { | ||||||
|  |   if (model_metadata_ == nullptr || | ||||||
|  |       model_metadata_->subgraph_metadata() == nullptr) { | ||||||
|  |     return nullptr; | ||||||
|  |   } | ||||||
|  |   return model_metadata_->subgraph_metadata() | ||||||
|  |       ->Get(kDefaultSubgraphIndex) | ||||||
|  |       ->custom_metadata(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | const tflite::CustomMetadata* ModelMetadataExtractor::GetCustomMetadata( | ||||||
|  |     int index) const { | ||||||
|  |   return GetItemFromVector<tflite::CustomMetadata>(GetCustomMetadataList(), | ||||||
|  |                                                    index); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | int ModelMetadataExtractor::GetCustomMetadataCount() const { | ||||||
|  |   const Vector<flatbuffers::Offset<tflite::CustomMetadata>>* custom_medata_vec = | ||||||
|  |       GetCustomMetadataList(); | ||||||
|  |   return custom_medata_vec == nullptr ? 0 : custom_medata_vec->size(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace metadata
 | }  // namespace metadata
 | ||||||
| }  // namespace tasks
 | }  // namespace tasks
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
|  | @ -136,6 +136,19 @@ class ModelMetadataExtractor { | ||||||
|   // there is no output process units.
 |   // there is no output process units.
 | ||||||
|   int GetOutputProcessUnitsCount() const; |   int GetOutputProcessUnitsCount() const; | ||||||
| 
 | 
 | ||||||
|  |   // Gets a list of custom metadata from SubgraphMetadata.custom_metadata,
 | ||||||
|  |   // could be nullptr.
 | ||||||
|  |   const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>* | ||||||
|  |   GetCustomMetadataList() const; | ||||||
|  | 
 | ||||||
|  |   // Gets the custom metadata specified by the given index, or nullptr in case
 | ||||||
|  |   // there is no custom metadata or the index is out of range.
 | ||||||
|  |   const tflite::CustomMetadata* GetCustomMetadata(int index) const; | ||||||
|  | 
 | ||||||
|  |   // Gets the count of custom metadata. In particular, 0 is returned when
 | ||||||
|  |   // there is no custom metadata.
 | ||||||
|  |   int GetCustomMetadataCount() const; | ||||||
|  | 
 | ||||||
|  private: |  private: | ||||||
|   static constexpr int kDefaultSubgraphIndex = 0; |   static constexpr int kDefaultSubgraphIndex = 0; | ||||||
|   // Private default constructor, called from CreateFromModel().
 |   // Private default constructor, called from CreateFromModel().
 | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ namespace metadata { | ||||||
| 
 | 
 | ||||||
| // The version of the metadata parser that this metadata versioning library is | // The version of the metadata parser that this metadata versioning library is | ||||||
| // depending on. | // depending on. | ||||||
| inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}"; | inline constexpr char kMetadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}"; | ||||||
| 
 | 
 | ||||||
| }  // namespace metadata | }  // namespace metadata | ||||||
| }  // namespace tasks | }  // namespace tasks | ||||||
|  |  | ||||||
|  | @ -57,6 +57,7 @@ enum class SchemaMembers { | ||||||
|   kContentPropertiesAudioProperties = 8, |   kContentPropertiesAudioProperties = 8, | ||||||
|   kAssociatedFileTypeScannIndexFile = 9, |   kAssociatedFileTypeScannIndexFile = 9, | ||||||
|   kAssociatedFileVersion = 10, |   kAssociatedFileVersion = 10, | ||||||
|  |   kCustomMetadata = 11, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // Helper class to compare semantic versions in terms of three integers, major,
 | // Helper class to compare semantic versions in terms of three integers, major,
 | ||||||
|  | @ -125,6 +126,8 @@ Version GetMemberVersion(SchemaMembers member) { | ||||||
|       return Version(1, 4, 0); |       return Version(1, 4, 0); | ||||||
|     case SchemaMembers::kAssociatedFileVersion: |     case SchemaMembers::kAssociatedFileVersion: | ||||||
|       return Version(1, 4, 1); |       return Version(1, 4, 1); | ||||||
|  |     case SchemaMembers::kCustomMetadata: | ||||||
|  |       return Version(1, 5, 0); | ||||||
|     default: |     default: | ||||||
|       // Should never happen.
 |       // Should never happen.
 | ||||||
|       TFLITE_LOG(FATAL) << "Unsupported schema member: " |       TFLITE_LOG(FATAL) << "Unsupported schema member: " | ||||||
|  | @ -281,6 +284,12 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>( | ||||||
|         GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups), |         GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups), | ||||||
|         min_version); |         min_version); | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   // Checks for the options field.
 | ||||||
|  |   if (table->custom_metadata() != nullptr) { | ||||||
|  |     UpdateMinimumVersion(GetMemberVersion(SchemaMembers::kCustomMetadata), | ||||||
|  |                          min_version); | ||||||
|  |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
|  |  | ||||||
							
								
								
									
										48
									
								
								mediapipe/tasks/cc/metadata/metadata_version_utils.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								mediapipe/tasks/cc/metadata/metadata_version_utils.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,48 @@ | ||||||
|  | #include "mediapipe/tasks/cc/metadata/metadata_version_utils.h" | ||||||
|  | 
 | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "absl/strings/str_split.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace tasks { | ||||||
|  | namespace metadata { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | static int32_t GetValueOrZero(const std::vector<std::string> &list, | ||||||
|  |                               const int index) { | ||||||
|  |   int32_t value = 0; | ||||||
|  |   if (index <= list.size() - 1) { | ||||||
|  |     value = std::stoi(list[index]); | ||||||
|  |   } | ||||||
|  |   return value; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | int CompareVersions(absl::string_view version_a, absl::string_view version_b) { | ||||||
|  |   std::vector<std::string> version_a_components = | ||||||
|  |       absl::StrSplit(version_a, '.', absl::SkipEmpty()); | ||||||
|  |   std::vector<std::string> version_b_components = | ||||||
|  |       absl::StrSplit(version_b, '.', absl::SkipEmpty()); | ||||||
|  | 
 | ||||||
|  |   const int a_length = version_a_components.size(); | ||||||
|  |   const int b_length = version_b_components.size(); | ||||||
|  |   const int max_length = std::max(a_length, b_length); | ||||||
|  | 
 | ||||||
|  |   for (int i = 0; i < max_length; ++i) { | ||||||
|  |     const int a_val = GetValueOrZero(version_a_components, i); | ||||||
|  |     const int b_val = GetValueOrZero(version_b_components, i); | ||||||
|  |     if (a_val > b_val) { | ||||||
|  |       return 1; | ||||||
|  |     } | ||||||
|  |     if (a_val < b_val) { | ||||||
|  |       return -1; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   return 0; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace metadata
 | ||||||
|  | }  // namespace tasks
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
							
								
								
									
										21
									
								
								mediapipe/tasks/cc/metadata/metadata_version_utils.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								mediapipe/tasks/cc/metadata/metadata_version_utils.h
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | ||||||
|  | #ifndef MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_ | ||||||
|  | #define MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_ | ||||||
|  | 
 | ||||||
|  | #include "absl/strings/string_view.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace tasks { | ||||||
|  | namespace metadata { | ||||||
|  | 
 | ||||||
|  | // Compares two versions. The version format is "**.**.**" such as "1.12.3".
 | ||||||
|  | // If version_a is newer than version_b, return 1; if version_a is
 | ||||||
|  | // older than version_b, return -1; if version_a equals to version_b,
 | ||||||
|  | // returns 0. For example, if version_a = 1.12.3 and version_b = 1.12.1,
 | ||||||
|  | // version_a is newer than version_b, and the function return is 1.
 | ||||||
|  | int CompareVersions(absl::string_view version_a, absl::string_view version_b); | ||||||
|  | 
 | ||||||
|  | }  // namespace metadata
 | ||||||
|  | }  // namespace tasks
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | 
 | ||||||
|  | #endif  // MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
 | ||||||
|  | @ -22,6 +22,7 @@ cc_test( | ||||||
|         "@com_google_absl//absl/status:statusor", |         "@com_google_absl//absl/status:statusor", | ||||||
|         "@com_google_absl//absl/strings", |         "@com_google_absl//absl/strings", | ||||||
|         "@com_google_absl//absl/strings:cord", |         "@com_google_absl//absl/strings:cord", | ||||||
|  |         "@com_google_absl//absl/strings:str_format", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -44,3 +45,12 @@ cc_test( | ||||||
|         "//mediapipe/tasks/cc/metadata:metadata_version", |         "//mediapipe/tasks/cc/metadata:metadata_version", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | cc_test( | ||||||
|  |     name = "metadata_version_utils_test", | ||||||
|  |     srcs = ["metadata_version_utils_test.cc"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/framework/port:gtest_main", | ||||||
|  |         "//mediapipe/tasks/cc/metadata:metadata_version_utils", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ limitations under the License. | ||||||
| #include "absl/status/statusor.h" | #include "absl/status/statusor.h" | ||||||
| #include "absl/strings/cord.h" | #include "absl/strings/cord.h" | ||||||
| #include "absl/strings/str_cat.h" | #include "absl/strings/str_cat.h" | ||||||
|  | #include "absl/strings/str_format.h" | ||||||
| #include "mediapipe/framework/port/file_helpers.h" | #include "mediapipe/framework/port/file_helpers.h" | ||||||
| #include "mediapipe/framework/port/gmock.h" | #include "mediapipe/framework/port/gmock.h" | ||||||
| #include "mediapipe/framework/port/gtest.h" | #include "mediapipe/framework/port/gtest.h" | ||||||
|  | @ -26,6 +27,7 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/port/status_matchers.h" | #include "mediapipe/framework/port/status_matchers.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
| #include "mediapipe/tasks/cc/core/utils.h" | #include "mediapipe/tasks/cc/core/utils.h" | ||||||
|  | #include "mediapipe/tasks/cc/metadata/metadata_parser.h" | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace tasks { | namespace tasks { | ||||||
|  | @ -43,6 +45,11 @@ constexpr char kMobileIcaWithoutTfLiteMetadata[] = | ||||||
|     "mobile_ica_8bit-without-model-metadata.tflite"; |     "mobile_ica_8bit-without-model-metadata.tflite"; | ||||||
| constexpr char kMobileIcaWithTfLiteMetadata[] = | constexpr char kMobileIcaWithTfLiteMetadata[] = | ||||||
|     "mobile_ica_8bit-with-metadata.tflite"; |     "mobile_ica_8bit-with-metadata.tflite"; | ||||||
|  | constexpr char kMobileIcaWithCustomMetadata[] = | ||||||
|  |     "mobile_ica_8bit-with-custom-metadata.tflite"; | ||||||
|  | // `min_parser_version=1000.0.0` for test purpose.
 | ||||||
|  | constexpr char kMobileIcaWithLargeMinParseVersion[] = | ||||||
|  |     "mobile_ica_8bit-with-large-min-parser-version.tflite"; | ||||||
| 
 | 
 | ||||||
| constexpr char kMobileIcaWithUnsupportedMetadataVersion[] = | constexpr char kMobileIcaWithUnsupportedMetadataVersion[] = | ||||||
|     "mobile_ica_8bit-with-unsupported-metadata-version.tflite"; |     "mobile_ica_8bit-with-unsupported-metadata-version.tflite"; | ||||||
|  | @ -334,6 +341,81 @@ TEST(ModelMetadataExtractorTest, GetModelVersionWorks) { | ||||||
|   MP_EXPECT_OK(extractor->GetModelVersion().status()); |   MP_EXPECT_OK(extractor->GetModelVersion().status()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, GetCustomMetadataListWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer)); | ||||||
|  |   EXPECT_TRUE(extractor->GetCustomMetadataList() != nullptr); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, | ||||||
|  |      GetCustomMetadataListWithoutTfLiteMetadataWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); | ||||||
|  |   EXPECT_TRUE(extractor->GetCustomMetadataList() == nullptr); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, GetCustomMetadataWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer)); | ||||||
|  |   EXPECT_TRUE(extractor->GetCustomMetadata(0) != nullptr); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, GetCustomMetadataWithoutTfLiteMetadataWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); | ||||||
|  |   EXPECT_TRUE(extractor->GetCustomMetadata(0) == nullptr); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, GetCustomMetadataWithOutOfRangeIndexWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); | ||||||
|  |   EXPECT_TRUE(extractor->GetCustomMetadata(-1) == nullptr); | ||||||
|  |   EXPECT_TRUE(extractor->GetCustomMetadata(2) == nullptr); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, GetCustomMetadataCountWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer)); | ||||||
|  |   EXPECT_EQ(extractor->GetCustomMetadataCount(), 2); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, | ||||||
|  |      GetCustomMetadataCountWithoutTfLiteMetadataWorks) { | ||||||
|  |   std::string buffer; | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       std::unique_ptr<ModelMetadataExtractor> extractor, | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer)); | ||||||
|  |   EXPECT_EQ(extractor->GetCustomMetadataCount(), 0); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(ModelMetadataExtractorTest, | ||||||
|  |      CreateFailsWithIncompatibleWithMinParserVersion) { | ||||||
|  |   std::string buffer; | ||||||
|  |   auto extractor = | ||||||
|  |       CreateMetadataExtractor(kMobileIcaWithLargeMinParseVersion, &buffer); | ||||||
|  |   EXPECT_THAT(extractor.status().code(), absl::StatusCode::kInvalidArgument); | ||||||
|  |   EXPECT_EQ( | ||||||
|  |       extractor.status().message(), | ||||||
|  |       absl::StrFormat("Metadata schema version %s is smaller than the minimum " | ||||||
|  |                       "version 1000.0.0 to parse the metadata flatbuffer.", | ||||||
|  |                       kMetadataParserVersion)); | ||||||
|  |   EXPECT_THAT(extractor.status().GetPayload(kMediaPipeTasksPayload), | ||||||
|  |               Optional(absl::Cord(absl::StrCat( | ||||||
|  |                   MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError)))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace metadata
 | }  // namespace metadata
 | ||||||
| }  // namespace tasks
 | }  // namespace tasks
 | ||||||
|  |  | ||||||
|  | @ -27,9 +27,9 @@ using ::testing::MatchesRegex; | ||||||
| TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) { | TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) { | ||||||
|   // Validates that the version is well-formed (x.y.z).
 |   // Validates that the version is well-formed (x.y.z).
 | ||||||
| #ifdef _WIN32 | #ifdef _WIN32 | ||||||
|   EXPECT_THAT(kMatadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+")); |   EXPECT_THAT(kMetadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+")); | ||||||
| #else | #else | ||||||
|   EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+")); |   EXPECT_THAT(kMetadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+")); | ||||||
| #endif  // _WIN32
 | #endif  // _WIN32
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -37,6 +37,8 @@ using ::tflite::AudioPropertiesBuilder; | ||||||
| using ::tflite::BertTokenizerOptionsBuilder; | using ::tflite::BertTokenizerOptionsBuilder; | ||||||
| using ::tflite::ContentBuilder; | using ::tflite::ContentBuilder; | ||||||
| using ::tflite::ContentProperties_AudioProperties; | using ::tflite::ContentProperties_AudioProperties; | ||||||
|  | using ::tflite::CustomMetadata; | ||||||
|  | using ::tflite::CustomMetadataBuilder; | ||||||
| using ::tflite::ModelMetadataBuilder; | using ::tflite::ModelMetadataBuilder; | ||||||
| using ::tflite::NormalizationOptionsBuilder; | using ::tflite::NormalizationOptionsBuilder; | ||||||
| using ::tflite::ProcessUnit; | using ::tflite::ProcessUnit; | ||||||
|  | @ -483,6 +485,34 @@ TEST(MetadataVersionTest, | ||||||
|   EXPECT_THAT(min_version, StrEq("1.4.1")); |   EXPECT_THAT(min_version, StrEq("1.4.1")); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) { | ||||||
|  |   // Creates a metadata flatbuffer with the field custom_metadata in subgraph
 | ||||||
|  |   // metadata.
 | ||||||
|  |   FlatBufferBuilder builder(1024); | ||||||
|  |   auto name = builder.CreateString("custom_metadata"); | ||||||
|  |   auto data = builder.CreateVector(std::vector<unsigned char>{'a'}); | ||||||
|  |   CustomMetadataBuilder custom_metadata_builder(builder); | ||||||
|  |   custom_metadata_builder.add_name(name); | ||||||
|  |   custom_metadata_builder.add_data(data); | ||||||
|  |   auto custom_metadata = builder.CreateVector( | ||||||
|  |       std::vector<Offset<CustomMetadata>>{custom_metadata_builder.Finish()}); | ||||||
|  |   SubGraphMetadataBuilder subgraph_builder(builder); | ||||||
|  |   subgraph_builder.add_custom_metadata(custom_metadata); | ||||||
|  |   auto subgraphs = builder.CreateVector( | ||||||
|  |       std::vector<Offset<SubGraphMetadata>>{subgraph_builder.Finish()}); | ||||||
|  |   ModelMetadataBuilder metadata_builder(builder); | ||||||
|  |   metadata_builder.add_subgraph_metadata(subgraphs); | ||||||
|  |   FinishModelMetadataBuffer(builder, metadata_builder.Finish()); | ||||||
|  | 
 | ||||||
|  |   // Gets the mimimum metadata parser version.
 | ||||||
|  |   std::string min_version; | ||||||
|  |   EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), | ||||||
|  |                                             builder.GetSize(), &min_version), | ||||||
|  |             kTfLiteOk); | ||||||
|  |   // Validates that the version is exactly 1.5.0.
 | ||||||
|  |   EXPECT_EQ(min_version, "1.5.0"); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace metadata
 | }  // namespace metadata
 | ||||||
| }  // namespace tasks
 | }  // namespace tasks
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,43 @@ | ||||||
|  | /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | #include "mediapipe/tasks/cc/metadata/metadata_version_utils.h" | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/port/gmock.h" | ||||||
|  | #include "mediapipe/framework/port/gtest.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace tasks { | ||||||
|  | namespace metadata { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | TEST(MetadataVersionTest, CompareVersions) { | ||||||
|  |   ASSERT_EQ(0, CompareVersions("1.0", "1.0")); | ||||||
|  |   ASSERT_EQ(0, CompareVersions("1", "1.0.0")); | ||||||
|  | 
 | ||||||
|  |   ASSERT_EQ(-1, CompareVersions("2", "3")); | ||||||
|  |   ASSERT_EQ(-1, CompareVersions("1.2", "1.3")); | ||||||
|  |   ASSERT_EQ(-1, CompareVersions("3.2.9", "3.2.10")); | ||||||
|  |   ASSERT_EQ(-1, CompareVersions("10.1.9", "10.2")); | ||||||
|  | 
 | ||||||
|  |   ASSERT_EQ(1, CompareVersions("3", "2")); | ||||||
|  |   ASSERT_EQ(1, CompareVersions("1.3", "1.2")); | ||||||
|  |   ASSERT_EQ(1, CompareVersions("0.95", "0.94.3124")); | ||||||
|  |   ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.1.9")); | ||||||
|  |   ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.0.13")); | ||||||
|  | } | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace metadata
 | ||||||
|  | }  // namespace tasks
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -49,8 +49,7 @@ namespace tflite; | ||||||
| // New fields and types will have associated comments with the schema version | // New fields and types will have associated comments with the schema version | ||||||
| // for which they were added. | // for which they were added. | ||||||
| // | // | ||||||
| // TODO: Add LINT change check as needed. | // Schema Semantic version: 1.5.0 | ||||||
| // Schema Semantic version: 1.4.1 |  | ||||||
| 
 | 
 | ||||||
| // This indicates the flatbuffer compatibility. The number will bump up when a | // This indicates the flatbuffer compatibility. The number will bump up when a | ||||||
| // break change is applied to the schema, such as removing fields or adding new | // break change is applied to the schema, such as removing fields or adding new | ||||||
|  | @ -69,11 +68,11 @@ file_identifier "M001"; | ||||||
| // 1.3.0 - Added AudioProperties to ContentProperties. | // 1.3.0 - Added AudioProperties to ContentProperties. | ||||||
| // 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType. | // 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType. | ||||||
| // 1.4.1 - Added version to AssociatedFile. | // 1.4.1 - Added version to AssociatedFile. | ||||||
|  | // 1.5.0 - Added CustomMetadata in SubGraphMetadata. | ||||||
| 
 | 
 | ||||||
| // File extension of any written files. | // File extension of any written files. | ||||||
| file_extension "tflitemeta"; | file_extension "tflitemeta"; | ||||||
| 
 | 
 | ||||||
| // TODO: Add LINT change check as needed. |  | ||||||
| enum AssociatedFileType : byte { | enum AssociatedFileType : byte { | ||||||
|   UNKNOWN = 0, |   UNKNOWN = 0, | ||||||
| 
 | 
 | ||||||
|  | @ -609,6 +608,11 @@ table TensorMetadata { | ||||||
|   associated_files:[AssociatedFile]; |   associated_files:[AssociatedFile]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | table CustomMetadata { | ||||||
|  |   name:string; | ||||||
|  |   data:[ubyte] (force_align: 16); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| table SubGraphMetadata { | table SubGraphMetadata { | ||||||
|   // Name of the subgraph. |   // Name of the subgraph. | ||||||
|   // |   // | ||||||
|  | @ -676,6 +680,7 @@ table SubGraphMetadata { | ||||||
|   // Added in: 1.2.0 |   // Added in: 1.2.0 | ||||||
|   output_tensor_groups:[TensorGroup]; |   output_tensor_groups:[TensorGroup]; | ||||||
| 
 | 
 | ||||||
|  |   custom_metadata:[CustomMetadata]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| table ModelMetadata { | table ModelMetadata { | ||||||
|  | @ -721,5 +726,6 @@ table ModelMetadata { | ||||||
|   // the metadata is populated into a TFLite model. |   // the metadata is populated into a TFLite model. | ||||||
|   min_parser_version:string; |   min_parser_version:string; | ||||||
| } | } | ||||||
|  | //     metadata_version.cc) | ||||||
| 
 | 
 | ||||||
| root_type ModelMetadata; | root_type ModelMetadata; | ||||||
|  |  | ||||||
|  | @ -14,6 +14,7 @@ | ||||||
| # ============================================================================== | # ============================================================================== | ||||||
| """Helper classes for common model metadata information.""" | """Helper classes for common model metadata information.""" | ||||||
| 
 | 
 | ||||||
|  | import abc | ||||||
| import collections | import collections | ||||||
| import csv | import csv | ||||||
| import os | import os | ||||||
|  | @ -377,7 +378,8 @@ class TensorMd: | ||||||
|     tensor_name: name of the corresponding tensor [1] in the TFLite model. It is |     tensor_name: name of the corresponding tensor [1] in the TFLite model. It is | ||||||
|       used to locate the corresponding tensor and decide the order of the tensor |       used to locate the corresponding tensor and decide the order of the tensor | ||||||
|       metadata [2] when populating model metadata. |       metadata [2] when populating model metadata. | ||||||
|     content_range_md: information of content range [3]. [1]: |     content_range_md: information of content range [3]. | ||||||
|  |     [1]: | ||||||
|       https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 |       https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 | ||||||
|     [2]: |     [2]: | ||||||
|       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 |       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640 | ||||||
|  | @ -777,7 +779,8 @@ class ClassificationTensorMd(TensorMd): | ||||||
|         order of the tensor metadata [4] when populating model metadata. |         order of the tensor metadata [4] when populating model metadata. | ||||||
|       score_thresholding_md: information of the score thresholding [5] in the |       score_thresholding_md: information of the score thresholding [5] in the | ||||||
|         classification tensor. |         classification tensor. | ||||||
|       content_range_md: information of content range [6]. [1]: |       content_range_md: information of content range [6]. | ||||||
|  |       [1]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 | ||||||
|       [2]: |       [2]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 | ||||||
|  | @ -890,7 +893,8 @@ class CategoryTensorMd(TensorMd): | ||||||
|       name: name of the tensor. |       name: name of the tensor. | ||||||
|       description: description of what the tensor is. |       description: description of what the tensor is. | ||||||
|       label_files: information of the label files [1] in the category tensor. |       label_files: information of the label files [1] in the category tensor. | ||||||
|       content_range_md: information of content range [2]. [1]: |       content_range_md: information of content range [2]. | ||||||
|  |       [1]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116 | ||||||
|       [2]: |       [2]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385 | ||||||
|  | @ -934,7 +938,8 @@ class DetectionOutputTensorsMd: | ||||||
|       label_files: information of the label files [1] in the classification |       label_files: information of the label files [1] in the classification | ||||||
|         tensor. |         tensor. | ||||||
|       score_calibration_md: information of the score calibration files operation |       score_calibration_md: information of the score calibration files operation | ||||||
|         [2] in the classification tensor. [1]: |         [2] in the classification tensor. | ||||||
|  |       [1]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 | ||||||
|       [2]: |       [2]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 | ||||||
|  | @ -1010,7 +1015,8 @@ class TensorGroupMd: | ||||||
|     Args: |     Args: | ||||||
|       name: name of tensor group. |       name: name of tensor group. | ||||||
|       tensor_names:  Names of the tensors to group together, corresponding to |       tensor_names:  Names of the tensors to group together, corresponding to | ||||||
|         TensorMetadata.name [1]. [1]: |         TensorMetadata.name [1]. | ||||||
|  |       [1]: | ||||||
|         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564 |         https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564 | ||||||
|     """ |     """ | ||||||
|     self.name = name |     self.name = name | ||||||
|  | @ -1022,3 +1028,14 @@ class TensorGroupMd: | ||||||
|     group.name = self.name |     group.name = self.name | ||||||
|     group.tensorNames = self.tensor_names |     group.tensorNames = self.tensor_names | ||||||
|     return group |     return group | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CustomMetadataMd(abc.ABC): | ||||||
|  |   """An abstract class of a container for the custom metadata information.""" | ||||||
|  | 
 | ||||||
|  |   def __init__(self, name: Optional[str] = None): | ||||||
|  |     self.name = name | ||||||
|  | 
 | ||||||
|  |   @abc.abstractmethod | ||||||
|  |   def create_metadata(self) -> _metadata_fb.CustomMetadataT: | ||||||
|  |     """Creates the custom metadata based on the information.""" | ||||||
|  |  | ||||||
|  | @ -317,6 +317,7 @@ def _create_metadata_buffer( | ||||||
|     output_md: Optional[List[metadata_info.TensorMd]] = None, |     output_md: Optional[List[metadata_info.TensorMd]] = None, | ||||||
|     input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None, |     input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None, | ||||||
|     output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None, |     output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None, | ||||||
|  |     custom_metadata_md: Optional[List[metadata_info.CustomMetadataMd]] = None, | ||||||
| ) -> bytearray: | ) -> bytearray: | ||||||
|   """Creates a buffer of the metadata. |   """Creates a buffer of the metadata. | ||||||
| 
 | 
 | ||||||
|  | @ -326,7 +327,9 @@ def _create_metadata_buffer( | ||||||
|     input_md: metadata information of the input tensors. |     input_md: metadata information of the input tensors. | ||||||
|     output_md: metadata information of the output tensors. |     output_md: metadata information of the output tensors. | ||||||
|     input_process_units: a lists of metadata of the input process units [1]. |     input_process_units: a lists of metadata of the input process units [1]. | ||||||
|     output_group_md: a list of metadata of output tensor groups [2]; [1]: |     output_group_md: a list of metadata of output tensor groups [2]; | ||||||
|  |     custom_metadata_md: a lists of custom metadata. | ||||||
|  |     [1]: | ||||||
|       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655 |       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655 | ||||||
|     [2]: |     [2]: | ||||||
|       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677 |       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677 | ||||||
|  | @ -367,6 +370,10 @@ def _create_metadata_buffer( | ||||||
|   subgraph_metadata.outputTensorMetadata = output_metadata |   subgraph_metadata.outputTensorMetadata = output_metadata | ||||||
|   if input_process_units: |   if input_process_units: | ||||||
|     subgraph_metadata.inputProcessUnits = input_process_units |     subgraph_metadata.inputProcessUnits = input_process_units | ||||||
|  |   if custom_metadata_md: | ||||||
|  |     subgraph_metadata.customMetadata = [ | ||||||
|  |         m.create_metadata() for m in custom_metadata_md | ||||||
|  |     ] | ||||||
|   if output_group_md: |   if output_group_md: | ||||||
|     subgraph_metadata.outputTensorGroups = [ |     subgraph_metadata.outputTensorGroups = [ | ||||||
|         m.create_metadata() for m in output_group_md |         m.create_metadata() for m in output_group_md | ||||||
|  | @ -416,6 +423,7 @@ class MetadataWriter(object): | ||||||
|     self._output_mds = [] |     self._output_mds = [] | ||||||
|     self._output_group_mds = [] |     self._output_group_mds = [] | ||||||
|     self._associated_files = [] |     self._associated_files = [] | ||||||
|  |     self._custom_metadata_mds = [] | ||||||
|     self._temp_folder = tempfile.TemporaryDirectory() |     self._temp_folder = tempfile.TemporaryDirectory() | ||||||
| 
 | 
 | ||||||
|   def __del__(self): |   def __del__(self): | ||||||
|  | @ -657,6 +665,12 @@ class MetadataWriter(object): | ||||||
|     self._output_mds.append(output_md) |     self._output_mds.append(output_md) | ||||||
|     return self |     return self | ||||||
| 
 | 
 | ||||||
|  |   def add_custom_metadata( | ||||||
|  |       self, custom_metadata_md: metadata_info.CustomMetadataMd | ||||||
|  |   ) -> 'MetadataWriter': | ||||||
|  |     self._custom_metadata_mds.append(custom_metadata_md) | ||||||
|  |     return self | ||||||
|  | 
 | ||||||
|   def populate(self) -> Tuple[bytearray, str]: |   def populate(self) -> Tuple[bytearray, str]: | ||||||
|     """Populates metadata into the TFLite file. |     """Populates metadata into the TFLite file. | ||||||
| 
 | 
 | ||||||
|  | @ -674,6 +688,7 @@ class MetadataWriter(object): | ||||||
|         input_md=self._input_mds, |         input_md=self._input_mds, | ||||||
|         output_md=self._output_mds, |         output_md=self._output_mds, | ||||||
|         input_process_units=self._input_process_units, |         input_process_units=self._input_process_units, | ||||||
|  |         custom_metadata_md=self._custom_metadata_mds, | ||||||
|         output_group_md=self._output_group_mds, |         output_group_md=self._output_group_mds, | ||||||
|     ) |     ) | ||||||
|     populator.load_metadata_buffer(metadata_buffer) |     populator.load_metadata_buffer(metadata_buffer) | ||||||
|  |  | ||||||
|  | @ -33,6 +33,8 @@ py_test( | ||||||
|         "//mediapipe/tasks/testdata/metadata:model_files", |         "//mediapipe/tasks/testdata/metadata:model_files", | ||||||
|     ], |     ], | ||||||
|     deps = [ |     deps = [ | ||||||
|  |         "//mediapipe/tasks/metadata:metadata_schema_py", | ||||||
|  |         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info", | ||||||
|         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", |         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", | ||||||
|         "//mediapipe/tasks/python/test:test_utils", |         "//mediapipe/tasks/python/test:test_utils", | ||||||
|     ], |     ], | ||||||
|  |  | ||||||
|  | @ -18,6 +18,8 @@ import tempfile | ||||||
| 
 | 
 | ||||||
| from absl.testing import absltest | from absl.testing import absltest | ||||||
| 
 | 
 | ||||||
|  | from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb | ||||||
|  | from mediapipe.tasks.python.metadata.metadata_writers import metadata_info | ||||||
| from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer | from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer | ||||||
| from mediapipe.tasks.python.test import test_utils | from mediapipe.tasks.python.test import test_utils | ||||||
| 
 | 
 | ||||||
|  | @ -29,6 +31,16 @@ _SCORE_CALIBRATION_FILE = test_utils.get_test_data_path( | ||||||
|     os.path.join(_TEST_DATA_DIR, 'score_calibration.txt')) |     os.path.join(_TEST_DATA_DIR, 'score_calibration.txt')) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class TestCustomMetadataMd(metadata_info.CustomMetadataMd): | ||||||
|  | 
 | ||||||
|  |   def create_metadata(self) -> _metadata_fb.CustomMetadataT: | ||||||
|  |     """Creates the custom metadata based on the information.""" | ||||||
|  |     custom_metadata_field = _metadata_fb.CustomMetadataT() | ||||||
|  |     custom_metadata_field.name = self.name | ||||||
|  |     custom_metadata_field.data = b'\x01\x02' | ||||||
|  |     return custom_metadata_field | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class LabelsTest(absltest.TestCase): | class LabelsTest(absltest.TestCase): | ||||||
| 
 | 
 | ||||||
|   def test_category_name(self): |   def test_category_name(self): | ||||||
|  | @ -415,7 +427,6 @@ class MetadataWriterForTaskTest(absltest.TestCase): | ||||||
|         score_thresholding=metadata_writer.ScoreThresholding( |         score_thresholding=metadata_writer.ScoreThresholding( | ||||||
|             global_score_threshold=0.5)) |             global_score_threshold=0.5)) | ||||||
|     _, metadata_json = writer.populate() |     _, metadata_json = writer.populate() | ||||||
|     print(metadata_json) |  | ||||||
|     self.assertJsonEqual( |     self.assertJsonEqual( | ||||||
|         metadata_json, """{ |         metadata_json, """{ | ||||||
|         "subgraph_metadata": [ |         "subgraph_metadata": [ | ||||||
|  | @ -465,6 +476,46 @@ class MetadataWriterForTaskTest(absltest.TestCase): | ||||||
|       } |       } | ||||||
|       """) |       """) | ||||||
| 
 | 
 | ||||||
|  |   def test_add_custom_metadata(self): | ||||||
|  |     writer = metadata_writer.MetadataWriter.create( | ||||||
|  |         self.image_classifier_model_buffer | ||||||
|  |     ) | ||||||
|  |     writer.add_custom_metadata( | ||||||
|  |         TestCustomMetadataMd(name='test_custom_metadata') | ||||||
|  |     ) | ||||||
|  |     _, metadata_json = writer.populate() | ||||||
|  |     self.assertJsonEqual( | ||||||
|  |         metadata_json, | ||||||
|  |         """ | ||||||
|  |         { | ||||||
|  |           "subgraph_metadata": [ | ||||||
|  |             { | ||||||
|  |               "input_tensor_metadata": [ | ||||||
|  |                 { | ||||||
|  |                   "name": "input" | ||||||
|  |                 } | ||||||
|  |               ], | ||||||
|  |               "output_tensor_metadata": [ | ||||||
|  |                 { | ||||||
|  |                   "name": "MobilenetV1/Predictions/Reshape_1" | ||||||
|  |                 } | ||||||
|  |               ], | ||||||
|  |               "custom_metadata": [ | ||||||
|  |                 { | ||||||
|  |                   "name": "test_custom_metadata", | ||||||
|  |                   "data": [ | ||||||
|  |                     1, | ||||||
|  |                     2 | ||||||
|  |                   ] | ||||||
|  |                 } | ||||||
|  |               ] | ||||||
|  |             } | ||||||
|  |           ], | ||||||
|  |           "min_parser_version": "1.5.0" | ||||||
|  |         } | ||||||
|  |         """, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   absltest.main() |   absltest.main() | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								mediapipe/tasks/testdata/metadata/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								mediapipe/tasks/testdata/metadata/BUILD
									
									
									
									
										vendored
									
									
								
							|  | @ -31,6 +31,8 @@ mediapipe_files(srcs = [ | ||||||
|     "efficientdet_lite0_v1.json", |     "efficientdet_lite0_v1.json", | ||||||
|     "efficientdet_lite0_v1.tflite", |     "efficientdet_lite0_v1.tflite", | ||||||
|     "labelmap.txt", |     "labelmap.txt", | ||||||
|  |     "mobile_ica_8bit-with-custom-metadata.tflite", | ||||||
|  |     "mobile_ica_8bit-with-large-min-parser-version.tflite", | ||||||
|     "mobile_ica_8bit-with-metadata.tflite", |     "mobile_ica_8bit-with-metadata.tflite", | ||||||
|     "mobile_ica_8bit-with-unsupported-metadata-version.tflite", |     "mobile_ica_8bit-with-unsupported-metadata-version.tflite", | ||||||
|     "mobile_ica_8bit-without-model-metadata.tflite", |     "mobile_ica_8bit-without-model-metadata.tflite", | ||||||
|  | @ -86,6 +88,8 @@ filegroup( | ||||||
|         "bert_text_classifier_no_metadata.tflite", |         "bert_text_classifier_no_metadata.tflite", | ||||||
|         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", |         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", | ||||||
|         "efficientdet_lite0_v1.tflite", |         "efficientdet_lite0_v1.tflite", | ||||||
|  |         "mobile_ica_8bit-with-custom-metadata.tflite", | ||||||
|  |         "mobile_ica_8bit-with-large-min-parser-version.tflite", | ||||||
|         "mobile_ica_8bit-with-metadata.tflite", |         "mobile_ica_8bit-with-metadata.tflite", | ||||||
|         "mobile_ica_8bit-with-unsupported-metadata-version.tflite", |         "mobile_ica_8bit-with-unsupported-metadata-version.tflite", | ||||||
|         "mobile_ica_8bit-without-model-metadata.tflite", |         "mobile_ica_8bit-without-model-metadata.tflite", | ||||||
|  |  | ||||||
							
								
								
									
										12
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							|  | @ -544,6 +544,18 @@ def external_files(): | ||||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"], |         urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"], | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  |     http_file( | ||||||
|  |         name = "com_google_mediapipe_mobile_ica_8bit-with-custom-metadata_tflite", | ||||||
|  |         sha256 = "31f34f0dd0dc39e69e9c3deb1e3f3278febeb82ecf57c235834348a75df8fb51", | ||||||
|  |         urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-custom-metadata.tflite?generation=1677906531317767"], | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     http_file( | ||||||
|  |         name = "com_google_mediapipe_mobile_ica_8bit-with-large-min-parser-version_tflite", | ||||||
|  |         sha256 = "53d0ea047682539964820fcfc5dc81f4928957470f453f2065f4c2ab87406803", | ||||||
|  |         urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-large-min-parser-version.tflite?generation=1677906534624784"], | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|     http_file( |     http_file( | ||||||
|         name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite", |         name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite", | ||||||
|         sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5", |         sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5", | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user