Internal change
PiperOrigin-RevId: 514001732
This commit is contained in:
parent
c98b4b6ec6
commit
dbe4175a08
|
@ -14,9 +14,13 @@ stamp_metadata_parser_version(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "metadata_extractor",
|
name = "metadata_extractor",
|
||||||
srcs = ["metadata_extractor.cc"],
|
srcs = ["metadata_extractor.cc"],
|
||||||
hdrs = ["metadata_extractor.h"],
|
hdrs = [
|
||||||
|
"metadata_extractor.h",
|
||||||
|
"metadata_parser_h",
|
||||||
|
],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":metadata_version_utils",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||||
|
@ -68,3 +72,10 @@ cc_library(
|
||||||
"@zlib//:zlib_minizip",
|
"@zlib//:zlib_minizip",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "metadata_version_utils",
|
||||||
|
srcs = ["metadata_version_utils.cc"],
|
||||||
|
hdrs = ["metadata_version_utils.h"],
|
||||||
|
deps = ["@com_google_absl//absl/strings"],
|
||||||
|
)
|
||||||
|
|
|
@ -26,6 +26,8 @@ limitations under the License.
|
||||||
#include "flatbuffers/flatbuffers.h"
|
#include "flatbuffers/flatbuffers.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_parser.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
@ -164,6 +166,18 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
|
||||||
return CreateStatusWithPayload(StatusCode::kInternal,
|
return CreateStatusWithPayload(StatusCode::kInternal,
|
||||||
"Expected Model Metadata not to be null.");
|
"Expected Model Metadata not to be null.");
|
||||||
}
|
}
|
||||||
|
auto min_parser_version = model_metadata_->min_parser_version();
|
||||||
|
if (min_parser_version != nullptr &&
|
||||||
|
CompareVersions(min_parser_version->c_str(), kMetadataParserVersion) >
|
||||||
|
0) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat(
|
||||||
|
"Metadata schema version %s is smaller than the minimum version "
|
||||||
|
"%s to parse the metadata flatbuffer.",
|
||||||
|
kMetadataParserVersion, min_parser_version->c_str()),
|
||||||
|
MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError);
|
||||||
|
}
|
||||||
return ExtractAssociatedFiles(buffer_data, buffer_size);
|
return ExtractAssociatedFiles(buffer_data, buffer_size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -299,6 +313,29 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
|
||||||
return output_process_units == nullptr ? 0 : output_process_units->size();
|
return output_process_units == nullptr ? 0 : output_process_units->size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>*
|
||||||
|
ModelMetadataExtractor::GetCustomMetadataList() const {
|
||||||
|
if (model_metadata_ == nullptr ||
|
||||||
|
model_metadata_->subgraph_metadata() == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return model_metadata_->subgraph_metadata()
|
||||||
|
->Get(kDefaultSubgraphIndex)
|
||||||
|
->custom_metadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const tflite::CustomMetadata* ModelMetadataExtractor::GetCustomMetadata(
|
||||||
|
int index) const {
|
||||||
|
return GetItemFromVector<tflite::CustomMetadata>(GetCustomMetadataList(),
|
||||||
|
index);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ModelMetadataExtractor::GetCustomMetadataCount() const {
|
||||||
|
const Vector<flatbuffers::Offset<tflite::CustomMetadata>>* custom_medata_vec =
|
||||||
|
GetCustomMetadataList();
|
||||||
|
return custom_medata_vec == nullptr ? 0 : custom_medata_vec->size();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace metadata
|
} // namespace metadata
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -136,6 +136,19 @@ class ModelMetadataExtractor {
|
||||||
// there is no output process units.
|
// there is no output process units.
|
||||||
int GetOutputProcessUnitsCount() const;
|
int GetOutputProcessUnitsCount() const;
|
||||||
|
|
||||||
|
// Gets a list of custom metadata from SubgraphMetadata.custom_metadata,
|
||||||
|
// could be nullptr.
|
||||||
|
const flatbuffers::Vector<flatbuffers::Offset<tflite::CustomMetadata>>*
|
||||||
|
GetCustomMetadataList() const;
|
||||||
|
|
||||||
|
// Gets the custom metadata specified by the given index, or nullptr in case
|
||||||
|
// there is no custom metadata or the index is out of range.
|
||||||
|
const tflite::CustomMetadata* GetCustomMetadata(int index) const;
|
||||||
|
|
||||||
|
// Gets the count of custom metadata. In particular, 0 is returned when
|
||||||
|
// there is no custom metadata.
|
||||||
|
int GetCustomMetadataCount() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr int kDefaultSubgraphIndex = 0;
|
static constexpr int kDefaultSubgraphIndex = 0;
|
||||||
// Private default constructor, called from CreateFromModel().
|
// Private default constructor, called from CreateFromModel().
|
||||||
|
|
|
@ -21,7 +21,7 @@ namespace metadata {
|
||||||
|
|
||||||
// The version of the metadata parser that this metadata versioning library is
|
// The version of the metadata parser that this metadata versioning library is
|
||||||
// depending on.
|
// depending on.
|
||||||
inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
|
inline constexpr char kMetadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
|
||||||
|
|
||||||
} // namespace metadata
|
} // namespace metadata
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
|
@ -57,6 +57,7 @@ enum class SchemaMembers {
|
||||||
kContentPropertiesAudioProperties = 8,
|
kContentPropertiesAudioProperties = 8,
|
||||||
kAssociatedFileTypeScannIndexFile = 9,
|
kAssociatedFileTypeScannIndexFile = 9,
|
||||||
kAssociatedFileVersion = 10,
|
kAssociatedFileVersion = 10,
|
||||||
|
kCustomMetadata = 11,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper class to compare semantic versions in terms of three integers, major,
|
// Helper class to compare semantic versions in terms of three integers, major,
|
||||||
|
@ -125,6 +126,8 @@ Version GetMemberVersion(SchemaMembers member) {
|
||||||
return Version(1, 4, 0);
|
return Version(1, 4, 0);
|
||||||
case SchemaMembers::kAssociatedFileVersion:
|
case SchemaMembers::kAssociatedFileVersion:
|
||||||
return Version(1, 4, 1);
|
return Version(1, 4, 1);
|
||||||
|
case SchemaMembers::kCustomMetadata:
|
||||||
|
return Version(1, 5, 0);
|
||||||
default:
|
default:
|
||||||
// Should never happen.
|
// Should never happen.
|
||||||
TFLITE_LOG(FATAL) << "Unsupported schema member: "
|
TFLITE_LOG(FATAL) << "Unsupported schema member: "
|
||||||
|
@ -281,6 +284,12 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
|
||||||
GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
|
GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
|
||||||
min_version);
|
min_version);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks for the options field.
|
||||||
|
if (table->custom_metadata() != nullptr) {
|
||||||
|
UpdateMinimumVersion(GetMemberVersion(SchemaMembers::kCustomMetadata),
|
||||||
|
min_version);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
48
mediapipe/tasks/cc/metadata/metadata_version_utils.cc
Normal file
48
mediapipe/tasks/cc/metadata/metadata_version_utils.cc
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace metadata {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static int32_t GetValueOrZero(const std::vector<std::string> &list,
|
||||||
|
const int index) {
|
||||||
|
int32_t value = 0;
|
||||||
|
if (index <= list.size() - 1) {
|
||||||
|
value = std::stoi(list[index]);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int CompareVersions(absl::string_view version_a, absl::string_view version_b) {
|
||||||
|
std::vector<std::string> version_a_components =
|
||||||
|
absl::StrSplit(version_a, '.', absl::SkipEmpty());
|
||||||
|
std::vector<std::string> version_b_components =
|
||||||
|
absl::StrSplit(version_b, '.', absl::SkipEmpty());
|
||||||
|
|
||||||
|
const int a_length = version_a_components.size();
|
||||||
|
const int b_length = version_b_components.size();
|
||||||
|
const int max_length = std::max(a_length, b_length);
|
||||||
|
|
||||||
|
for (int i = 0; i < max_length; ++i) {
|
||||||
|
const int a_val = GetValueOrZero(version_a_components, i);
|
||||||
|
const int b_val = GetValueOrZero(version_b_components, i);
|
||||||
|
if (a_val > b_val) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (a_val < b_val) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace metadata
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
21
mediapipe/tasks/cc/metadata/metadata_version_utils.h
Normal file
21
mediapipe/tasks/cc/metadata/metadata_version_utils.h
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace metadata {
|
||||||
|
|
||||||
|
// Compares two versions. The version format is "**.**.**" such as "1.12.3".
|
||||||
|
// If version_a is newer than version_b, return 1; if version_a is
|
||||||
|
// older than version_b, return -1; if version_a equals to version_b,
|
||||||
|
// returns 0. For example, if version_a = 1.12.3 and version_b = 1.12.1,
|
||||||
|
// version_a is newer than version_b, and the function return is 1.
|
||||||
|
int CompareVersions(absl::string_view version_a, absl::string_view version_b);
|
||||||
|
|
||||||
|
} // namespace metadata
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_METADATA_METADATA_VERSION_UTIL_H_
|
|
@ -22,6 +22,7 @@ cc_test(
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:cord",
|
"@com_google_absl//absl/strings:cord",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,3 +45,12 @@ cc_test(
|
||||||
"//mediapipe/tasks/cc/metadata:metadata_version",
|
"//mediapipe/tasks/cc/metadata:metadata_version",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "metadata_version_utils_test",
|
||||||
|
srcs = ["metadata_version_utils_test.cc"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_version_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/cord.h"
|
#include "absl/strings/cord.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/framework/port/file_helpers.h"
|
#include "mediapipe/framework/port/file_helpers.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -26,6 +27,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_parser.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -43,6 +45,11 @@ constexpr char kMobileIcaWithoutTfLiteMetadata[] =
|
||||||
"mobile_ica_8bit-without-model-metadata.tflite";
|
"mobile_ica_8bit-without-model-metadata.tflite";
|
||||||
constexpr char kMobileIcaWithTfLiteMetadata[] =
|
constexpr char kMobileIcaWithTfLiteMetadata[] =
|
||||||
"mobile_ica_8bit-with-metadata.tflite";
|
"mobile_ica_8bit-with-metadata.tflite";
|
||||||
|
constexpr char kMobileIcaWithCustomMetadata[] =
|
||||||
|
"mobile_ica_8bit-with-custom-metadata.tflite";
|
||||||
|
// `min_parser_version=1000.0.0` for test purpose.
|
||||||
|
constexpr char kMobileIcaWithLargeMinParseVersion[] =
|
||||||
|
"mobile_ica_8bit-with-large-min-parser-version.tflite";
|
||||||
|
|
||||||
constexpr char kMobileIcaWithUnsupportedMetadataVersion[] =
|
constexpr char kMobileIcaWithUnsupportedMetadataVersion[] =
|
||||||
"mobile_ica_8bit-with-unsupported-metadata-version.tflite";
|
"mobile_ica_8bit-with-unsupported-metadata-version.tflite";
|
||||||
|
@ -334,6 +341,81 @@ TEST(ModelMetadataExtractorTest, GetModelVersionWorks) {
|
||||||
MP_EXPECT_OK(extractor->GetModelVersion().status());
|
MP_EXPECT_OK(extractor->GetModelVersion().status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest, GetCustomMetadataListWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
|
||||||
|
EXPECT_TRUE(extractor->GetCustomMetadataList() != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest,
|
||||||
|
GetCustomMetadataListWithoutTfLiteMetadataWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||||
|
EXPECT_TRUE(extractor->GetCustomMetadataList() == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest, GetCustomMetadataWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
|
||||||
|
EXPECT_TRUE(extractor->GetCustomMetadata(0) != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest, GetCustomMetadataWithoutTfLiteMetadataWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||||
|
EXPECT_TRUE(extractor->GetCustomMetadata(0) == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest, GetCustomMetadataWithOutOfRangeIndexWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||||
|
EXPECT_TRUE(extractor->GetCustomMetadata(-1) == nullptr);
|
||||||
|
EXPECT_TRUE(extractor->GetCustomMetadata(2) == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest, GetCustomMetadataCountWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithCustomMetadata, &buffer));
|
||||||
|
EXPECT_EQ(extractor->GetCustomMetadataCount(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest,
|
||||||
|
GetCustomMetadataCountWithoutTfLiteMetadataWorks) {
|
||||||
|
std::string buffer;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<ModelMetadataExtractor> extractor,
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithoutTfLiteMetadata, &buffer));
|
||||||
|
EXPECT_EQ(extractor->GetCustomMetadataCount(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelMetadataExtractorTest,
|
||||||
|
CreateFailsWithIncompatibleWithMinParserVersion) {
|
||||||
|
std::string buffer;
|
||||||
|
auto extractor =
|
||||||
|
CreateMetadataExtractor(kMobileIcaWithLargeMinParseVersion, &buffer);
|
||||||
|
EXPECT_THAT(extractor.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_EQ(
|
||||||
|
extractor.status().message(),
|
||||||
|
absl::StrFormat("Metadata schema version %s is smaller than the minimum "
|
||||||
|
"version 1000.0.0 to parse the metadata flatbuffer.",
|
||||||
|
kMetadataParserVersion));
|
||||||
|
EXPECT_THAT(extractor.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
MediaPipeTasksStatus::kMetadataInvalidSchemaVersionError))));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace metadata
|
} // namespace metadata
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
|
@ -27,9 +27,9 @@ using ::testing::MatchesRegex;
|
||||||
TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) {
|
TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) {
|
||||||
// Validates that the version is well-formed (x.y.z).
|
// Validates that the version is well-formed (x.y.z).
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+"));
|
EXPECT_THAT(kMetadataParserVersion, MatchesRegex("\\d+\\.\\d+\\.\\d+"));
|
||||||
#else
|
#else
|
||||||
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
|
EXPECT_THAT(kMetadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,8 @@ using ::tflite::AudioPropertiesBuilder;
|
||||||
using ::tflite::BertTokenizerOptionsBuilder;
|
using ::tflite::BertTokenizerOptionsBuilder;
|
||||||
using ::tflite::ContentBuilder;
|
using ::tflite::ContentBuilder;
|
||||||
using ::tflite::ContentProperties_AudioProperties;
|
using ::tflite::ContentProperties_AudioProperties;
|
||||||
|
using ::tflite::CustomMetadata;
|
||||||
|
using ::tflite::CustomMetadataBuilder;
|
||||||
using ::tflite::ModelMetadataBuilder;
|
using ::tflite::ModelMetadataBuilder;
|
||||||
using ::tflite::NormalizationOptionsBuilder;
|
using ::tflite::NormalizationOptionsBuilder;
|
||||||
using ::tflite::ProcessUnit;
|
using ::tflite::ProcessUnit;
|
||||||
|
@ -483,6 +485,34 @@ TEST(MetadataVersionTest,
|
||||||
EXPECT_THAT(min_version, StrEq("1.4.1"));
|
EXPECT_THAT(min_version, StrEq("1.4.1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
|
||||||
|
// Creates a metadata flatbuffer with the field custom_metadata in subgraph
|
||||||
|
// metadata.
|
||||||
|
FlatBufferBuilder builder(1024);
|
||||||
|
auto name = builder.CreateString("custom_metadata");
|
||||||
|
auto data = builder.CreateVector(std::vector<unsigned char>{'a'});
|
||||||
|
CustomMetadataBuilder custom_metadata_builder(builder);
|
||||||
|
custom_metadata_builder.add_name(name);
|
||||||
|
custom_metadata_builder.add_data(data);
|
||||||
|
auto custom_metadata = builder.CreateVector(
|
||||||
|
std::vector<Offset<CustomMetadata>>{custom_metadata_builder.Finish()});
|
||||||
|
SubGraphMetadataBuilder subgraph_builder(builder);
|
||||||
|
subgraph_builder.add_custom_metadata(custom_metadata);
|
||||||
|
auto subgraphs = builder.CreateVector(
|
||||||
|
std::vector<Offset<SubGraphMetadata>>{subgraph_builder.Finish()});
|
||||||
|
ModelMetadataBuilder metadata_builder(builder);
|
||||||
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
|
// Gets the mimimum metadata parser version.
|
||||||
|
std::string min_version;
|
||||||
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
|
builder.GetSize(), &min_version),
|
||||||
|
kTfLiteOk);
|
||||||
|
// Validates that the version is exactly 1.5.0.
|
||||||
|
EXPECT_EQ(min_version, "1.5.0");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace metadata
|
} // namespace metadata
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_version_utils.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace metadata {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(MetadataVersionTest, CompareVersions) {
|
||||||
|
ASSERT_EQ(0, CompareVersions("1.0", "1.0"));
|
||||||
|
ASSERT_EQ(0, CompareVersions("1", "1.0.0"));
|
||||||
|
|
||||||
|
ASSERT_EQ(-1, CompareVersions("2", "3"));
|
||||||
|
ASSERT_EQ(-1, CompareVersions("1.2", "1.3"));
|
||||||
|
ASSERT_EQ(-1, CompareVersions("3.2.9", "3.2.10"));
|
||||||
|
ASSERT_EQ(-1, CompareVersions("10.1.9", "10.2"));
|
||||||
|
|
||||||
|
ASSERT_EQ(1, CompareVersions("3", "2"));
|
||||||
|
ASSERT_EQ(1, CompareVersions("1.3", "1.2"));
|
||||||
|
ASSERT_EQ(1, CompareVersions("0.95", "0.94.3124"));
|
||||||
|
ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.1.9"));
|
||||||
|
ASSERT_EQ(1, CompareVersions("1.1.1.12", "1.1.0.13"));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace metadata
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -49,8 +49,7 @@ namespace tflite;
|
||||||
// New fields and types will have associated comments with the schema version
|
// New fields and types will have associated comments with the schema version
|
||||||
// for which they were added.
|
// for which they were added.
|
||||||
//
|
//
|
||||||
// TODO: Add LINT change check as needed.
|
// Schema Semantic version: 1.5.0
|
||||||
// Schema Semantic version: 1.4.1
|
|
||||||
|
|
||||||
// This indicates the flatbuffer compatibility. The number will bump up when a
|
// This indicates the flatbuffer compatibility. The number will bump up when a
|
||||||
// break change is applied to the schema, such as removing fields or adding new
|
// break change is applied to the schema, such as removing fields or adding new
|
||||||
|
@ -69,11 +68,11 @@ file_identifier "M001";
|
||||||
// 1.3.0 - Added AudioProperties to ContentProperties.
|
// 1.3.0 - Added AudioProperties to ContentProperties.
|
||||||
// 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType.
|
// 1.4.0 - Added SCANN_INDEX_FILE type to AssociatedFileType.
|
||||||
// 1.4.1 - Added version to AssociatedFile.
|
// 1.4.1 - Added version to AssociatedFile.
|
||||||
|
// 1.5.0 - Added CustomMetadata in SubGraphMetadata.
|
||||||
|
|
||||||
// File extension of any written files.
|
// File extension of any written files.
|
||||||
file_extension "tflitemeta";
|
file_extension "tflitemeta";
|
||||||
|
|
||||||
// TODO: Add LINT change check as needed.
|
|
||||||
enum AssociatedFileType : byte {
|
enum AssociatedFileType : byte {
|
||||||
UNKNOWN = 0,
|
UNKNOWN = 0,
|
||||||
|
|
||||||
|
@ -609,6 +608,11 @@ table TensorMetadata {
|
||||||
associated_files:[AssociatedFile];
|
associated_files:[AssociatedFile];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table CustomMetadata {
|
||||||
|
name:string;
|
||||||
|
data:[ubyte] (force_align: 16);
|
||||||
|
}
|
||||||
|
|
||||||
table SubGraphMetadata {
|
table SubGraphMetadata {
|
||||||
// Name of the subgraph.
|
// Name of the subgraph.
|
||||||
//
|
//
|
||||||
|
@ -676,6 +680,7 @@ table SubGraphMetadata {
|
||||||
// Added in: 1.2.0
|
// Added in: 1.2.0
|
||||||
output_tensor_groups:[TensorGroup];
|
output_tensor_groups:[TensorGroup];
|
||||||
|
|
||||||
|
custom_metadata:[CustomMetadata];
|
||||||
}
|
}
|
||||||
|
|
||||||
table ModelMetadata {
|
table ModelMetadata {
|
||||||
|
@ -721,5 +726,6 @@ table ModelMetadata {
|
||||||
// the metadata is populated into a TFLite model.
|
// the metadata is populated into a TFLite model.
|
||||||
min_parser_version:string;
|
min_parser_version:string;
|
||||||
}
|
}
|
||||||
|
// metadata_version.cc)
|
||||||
|
|
||||||
root_type ModelMetadata;
|
root_type ModelMetadata;
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Helper classes for common model metadata information."""
|
"""Helper classes for common model metadata information."""
|
||||||
|
|
||||||
|
import abc
|
||||||
import collections
|
import collections
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
|
@ -377,11 +378,12 @@ class TensorMd:
|
||||||
tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
|
tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
|
||||||
used to locate the corresponding tensor and decide the order of the tensor
|
used to locate the corresponding tensor and decide the order of the tensor
|
||||||
metadata [2] when populating model metadata.
|
metadata [2] when populating model metadata.
|
||||||
content_range_md: information of content range [3]. [1]:
|
content_range_md: information of content range [3].
|
||||||
|
[1]:
|
||||||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||||
[2]:
|
[2]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||||
[3]:
|
[3]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -777,17 +779,18 @@ class ClassificationTensorMd(TensorMd):
|
||||||
order of the tensor metadata [4] when populating model metadata.
|
order of the tensor metadata [4] when populating model metadata.
|
||||||
score_thresholding_md: information of the score thresholding [5] in the
|
score_thresholding_md: information of the score thresholding [5] in the
|
||||||
classification tensor.
|
classification tensor.
|
||||||
content_range_md: information of content range [6]. [1]:
|
content_range_md: information of content range [6].
|
||||||
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||||
[2]:
|
[2]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||||
[3]:
|
[3]:
|
||||||
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
|
||||||
[4]:
|
[4]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
|
||||||
[5]:
|
[5]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||||
[6]:
|
[6]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
||||||
"""
|
"""
|
||||||
self.score_calibration_md = score_calibration_md
|
self.score_calibration_md = score_calibration_md
|
||||||
|
@ -890,9 +893,10 @@ class CategoryTensorMd(TensorMd):
|
||||||
name: name of the tensor.
|
name: name of the tensor.
|
||||||
description: description of what the tensor is.
|
description: description of what the tensor is.
|
||||||
label_files: information of the label files [1] in the category tensor.
|
label_files: information of the label files [1] in the category tensor.
|
||||||
content_range_md: information of content range [2]. [1]:
|
content_range_md: information of content range [2].
|
||||||
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116
|
||||||
[2]:
|
[2]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L385
|
||||||
"""
|
"""
|
||||||
# In category tensors, label files are in the type of TENSOR_VALUE_LABELS.
|
# In category tensors, label files are in the type of TENSOR_VALUE_LABELS.
|
||||||
|
@ -934,9 +938,10 @@ class DetectionOutputTensorsMd:
|
||||||
label_files: information of the label files [1] in the classification
|
label_files: information of the label files [1] in the classification
|
||||||
tensor.
|
tensor.
|
||||||
score_calibration_md: information of the score calibration files operation
|
score_calibration_md: information of the score calibration files operation
|
||||||
[2] in the classification tensor. [1]:
|
[2] in the classification tensor.
|
||||||
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||||
[2]:
|
[2]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||||
"""
|
"""
|
||||||
content_range_md = ValueRangeMd(
|
content_range_md = ValueRangeMd(
|
||||||
|
@ -1010,7 +1015,8 @@ class TensorGroupMd:
|
||||||
Args:
|
Args:
|
||||||
name: name of tensor group.
|
name: name of tensor group.
|
||||||
tensor_names: Names of the tensors to group together, corresponding to
|
tensor_names: Names of the tensors to group together, corresponding to
|
||||||
TensorMetadata.name [1]. [1]:
|
TensorMetadata.name [1].
|
||||||
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L564
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -1022,3 +1028,14 @@ class TensorGroupMd:
|
||||||
group.name = self.name
|
group.name = self.name
|
||||||
group.tensorNames = self.tensor_names
|
group.tensorNames = self.tensor_names
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
class CustomMetadataMd(abc.ABC):
|
||||||
|
"""An abstract class of a container for the custom metadata information."""
|
||||||
|
|
||||||
|
def __init__(self, name: Optional[str] = None):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
|
||||||
|
"""Creates the custom metadata based on the information."""
|
||||||
|
|
|
@ -317,6 +317,7 @@ def _create_metadata_buffer(
|
||||||
output_md: Optional[List[metadata_info.TensorMd]] = None,
|
output_md: Optional[List[metadata_info.TensorMd]] = None,
|
||||||
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None,
|
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None,
|
||||||
output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None,
|
output_group_md: Optional[List[metadata_info.TensorGroupMd]] = None,
|
||||||
|
custom_metadata_md: Optional[List[metadata_info.CustomMetadataMd]] = None,
|
||||||
) -> bytearray:
|
) -> bytearray:
|
||||||
"""Creates a buffer of the metadata.
|
"""Creates a buffer of the metadata.
|
||||||
|
|
||||||
|
@ -326,9 +327,11 @@ def _create_metadata_buffer(
|
||||||
input_md: metadata information of the input tensors.
|
input_md: metadata information of the input tensors.
|
||||||
output_md: metadata information of the output tensors.
|
output_md: metadata information of the output tensors.
|
||||||
input_process_units: a lists of metadata of the input process units [1].
|
input_process_units: a lists of metadata of the input process units [1].
|
||||||
output_group_md: a list of metadata of output tensor groups [2]; [1]:
|
output_group_md: a list of metadata of output tensor groups [2];
|
||||||
|
custom_metadata_md: a lists of custom metadata.
|
||||||
|
[1]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
|
||||||
[2]:
|
[2]:
|
||||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L677
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -367,6 +370,10 @@ def _create_metadata_buffer(
|
||||||
subgraph_metadata.outputTensorMetadata = output_metadata
|
subgraph_metadata.outputTensorMetadata = output_metadata
|
||||||
if input_process_units:
|
if input_process_units:
|
||||||
subgraph_metadata.inputProcessUnits = input_process_units
|
subgraph_metadata.inputProcessUnits = input_process_units
|
||||||
|
if custom_metadata_md:
|
||||||
|
subgraph_metadata.customMetadata = [
|
||||||
|
m.create_metadata() for m in custom_metadata_md
|
||||||
|
]
|
||||||
if output_group_md:
|
if output_group_md:
|
||||||
subgraph_metadata.outputTensorGroups = [
|
subgraph_metadata.outputTensorGroups = [
|
||||||
m.create_metadata() for m in output_group_md
|
m.create_metadata() for m in output_group_md
|
||||||
|
@ -416,6 +423,7 @@ class MetadataWriter(object):
|
||||||
self._output_mds = []
|
self._output_mds = []
|
||||||
self._output_group_mds = []
|
self._output_group_mds = []
|
||||||
self._associated_files = []
|
self._associated_files = []
|
||||||
|
self._custom_metadata_mds = []
|
||||||
self._temp_folder = tempfile.TemporaryDirectory()
|
self._temp_folder = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
@ -657,6 +665,12 @@ class MetadataWriter(object):
|
||||||
self._output_mds.append(output_md)
|
self._output_mds.append(output_md)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def add_custom_metadata(
|
||||||
|
self, custom_metadata_md: metadata_info.CustomMetadataMd
|
||||||
|
) -> 'MetadataWriter':
|
||||||
|
self._custom_metadata_mds.append(custom_metadata_md)
|
||||||
|
return self
|
||||||
|
|
||||||
def populate(self) -> Tuple[bytearray, str]:
|
def populate(self) -> Tuple[bytearray, str]:
|
||||||
"""Populates metadata into the TFLite file.
|
"""Populates metadata into the TFLite file.
|
||||||
|
|
||||||
|
@ -674,6 +688,7 @@ class MetadataWriter(object):
|
||||||
input_md=self._input_mds,
|
input_md=self._input_mds,
|
||||||
output_md=self._output_mds,
|
output_md=self._output_mds,
|
||||||
input_process_units=self._input_process_units,
|
input_process_units=self._input_process_units,
|
||||||
|
custom_metadata_md=self._custom_metadata_mds,
|
||||||
output_group_md=self._output_group_mds,
|
output_group_md=self._output_group_mds,
|
||||||
)
|
)
|
||||||
populator.load_metadata_buffer(metadata_buffer)
|
populator.load_metadata_buffer(metadata_buffer)
|
||||||
|
|
|
@ -33,6 +33,8 @@ py_test(
|
||||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
|
|
|
@ -18,6 +18,8 @@ import tempfile
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
|
||||||
|
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
@ -29,6 +31,16 @@ _SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||||
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
|
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomMetadataMd(metadata_info.CustomMetadataMd):
|
||||||
|
|
||||||
|
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
|
||||||
|
"""Creates the custom metadata based on the information."""
|
||||||
|
custom_metadata_field = _metadata_fb.CustomMetadataT()
|
||||||
|
custom_metadata_field.name = self.name
|
||||||
|
custom_metadata_field.data = b'\x01\x02'
|
||||||
|
return custom_metadata_field
|
||||||
|
|
||||||
|
|
||||||
class LabelsTest(absltest.TestCase):
|
class LabelsTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_category_name(self):
|
def test_category_name(self):
|
||||||
|
@ -415,7 +427,6 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
||||||
score_thresholding=metadata_writer.ScoreThresholding(
|
score_thresholding=metadata_writer.ScoreThresholding(
|
||||||
global_score_threshold=0.5))
|
global_score_threshold=0.5))
|
||||||
_, metadata_json = writer.populate()
|
_, metadata_json = writer.populate()
|
||||||
print(metadata_json)
|
|
||||||
self.assertJsonEqual(
|
self.assertJsonEqual(
|
||||||
metadata_json, """{
|
metadata_json, """{
|
||||||
"subgraph_metadata": [
|
"subgraph_metadata": [
|
||||||
|
@ -465,6 +476,46 @@ class MetadataWriterForTaskTest(absltest.TestCase):
|
||||||
}
|
}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
def test_add_custom_metadata(self):
|
||||||
|
writer = metadata_writer.MetadataWriter.create(
|
||||||
|
self.image_classifier_model_buffer
|
||||||
|
)
|
||||||
|
writer.add_custom_metadata(
|
||||||
|
TestCustomMetadataMd(name='test_custom_metadata')
|
||||||
|
)
|
||||||
|
_, metadata_json = writer.populate()
|
||||||
|
self.assertJsonEqual(
|
||||||
|
metadata_json,
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "input"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "MobilenetV1/Predictions/Reshape_1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"custom_metadata": [
|
||||||
|
{
|
||||||
|
"name": "test_custom_metadata",
|
||||||
|
"data": [
|
||||||
|
1,
|
||||||
|
2
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.5.0"
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
4
mediapipe/tasks/testdata/metadata/BUILD
vendored
4
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -31,6 +31,8 @@ mediapipe_files(srcs = [
|
||||||
"efficientdet_lite0_v1.json",
|
"efficientdet_lite0_v1.json",
|
||||||
"efficientdet_lite0_v1.tflite",
|
"efficientdet_lite0_v1.tflite",
|
||||||
"labelmap.txt",
|
"labelmap.txt",
|
||||||
|
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||||
|
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||||
"mobile_ica_8bit-with-metadata.tflite",
|
"mobile_ica_8bit-with-metadata.tflite",
|
||||||
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
|
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
|
||||||
"mobile_ica_8bit-without-model-metadata.tflite",
|
"mobile_ica_8bit-without-model-metadata.tflite",
|
||||||
|
@ -86,6 +88,8 @@ filegroup(
|
||||||
"bert_text_classifier_no_metadata.tflite",
|
"bert_text_classifier_no_metadata.tflite",
|
||||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||||
"efficientdet_lite0_v1.tflite",
|
"efficientdet_lite0_v1.tflite",
|
||||||
|
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||||
|
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||||
"mobile_ica_8bit-with-metadata.tflite",
|
"mobile_ica_8bit-with-metadata.tflite",
|
||||||
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
|
"mobile_ica_8bit-with-unsupported-metadata-version.tflite",
|
||||||
"mobile_ica_8bit-without-model-metadata.tflite",
|
"mobile_ica_8bit-without-model-metadata.tflite",
|
||||||
|
|
12
third_party/external_files.bzl
vendored
12
third_party/external_files.bzl
vendored
|
@ -544,6 +544,18 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_with_metadata.tflite?generation=1661875806733025"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobile_ica_8bit-with-custom-metadata_tflite",
|
||||||
|
sha256 = "31f34f0dd0dc39e69e9c3deb1e3f3278febeb82ecf57c235834348a75df8fb51",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-custom-metadata.tflite?generation=1677906531317767"],
|
||||||
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobile_ica_8bit-with-large-min-parser-version_tflite",
|
||||||
|
sha256 = "53d0ea047682539964820fcfc5dc81f4928957470f453f2065f4c2ab87406803",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-large-min-parser-version.tflite?generation=1677906534624784"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite",
|
name = "com_google_mediapipe_mobile_ica_8bit-with-metadata_tflite",
|
||||||
sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5",
|
sha256 = "4afa3970d3efd6726d147d505e28c7ff1e4fe1c24be7bcda6b5429eb099777a5",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user