Internal change

PiperOrigin-RevId: 481255129
This commit is contained in:
Yuqi Li 2022-10-14 16:11:15 -07:00 committed by Copybara-Service
parent ca28a19822
commit eb52b72707
8 changed files with 476 additions and 10 deletions

View File

@ -138,6 +138,7 @@ cc_test_with_tflite(
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable", "@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
], ],
deps = [ deps = [
":utils",
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
@ -314,3 +315,41 @@ cc_library(
"@flatbuffers//:runtime_cc", "@flatbuffers//:runtime_cc",
], ],
) )
cc_library(
name = "model_asset_bundle_resources",
srcs = ["model_asset_bundle_resources.cc"],
hdrs = ["model_asset_bundle_resources.h"],
deps = [
":external_file_handler",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/util:resource_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)
cc_test(
name = "model_asset_bundle_resources_test",
srcs = ["model_asset_bundle_resources_test.cc"],
data = [
"//mediapipe/tasks/testdata/core:test_models",
],
deps = [
":model_asset_bundle_resources",
":model_resources",
":utils",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"@org_tensorflow//tensorflow/lite/c:common",
],
)

View File

@ -0,0 +1,107 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/util/resource_util.h"
namespace mediapipe {
namespace tasks {
namespace core {
namespace {
using ::absl::StatusCode;
} // namespace
ModelAssetBundleResources::ModelAssetBundleResources(
const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file)
: tag_(tag), model_asset_bundle_file_(std::move(model_asset_bundle_file)) {}
/* static */
absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>>
ModelAssetBundleResources::Create(
const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file) {
if (model_asset_bundle_file == nullptr) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"The model asset bundle file proto cannot be nullptr.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
auto model_bundle_resources = absl::WrapUnique(
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
MP_RETURN_IF_ERROR(
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
return model_bundle_resources;
}
absl::Status
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
if (model_asset_bundle_file_->has_file_name()) {
// If the model asset bundle file name is a relative path, searches the file
// in a platform-specific location and returns the absolute path on success.
ASSIGN_OR_RETURN(
std::string path_to_resource,
mediapipe::PathToResourceAsFile(model_asset_bundle_file_->file_name()));
model_asset_bundle_file_->set_file_name(path_to_resource);
}
ASSIGN_OR_RETURN(model_asset_bundle_file_handler_,
ExternalFileHandler::CreateFromExternalFile(
model_asset_bundle_file_.get()));
const char* buffer_data =
model_asset_bundle_file_handler_->GetFileContent().data();
size_t buffer_size =
model_asset_bundle_file_handler_->GetFileContent().size();
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
&model_files_);
}
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
const std::string& filename) const {
auto it = model_files_.find(filename);
if (it == model_files_.end()) {
auto model_files = ListModelFiles();
std::string all_model_files =
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
return CreateStatusWithPayload(
StatusCode::kNotFound,
absl::StrFormat("No model file with name: %s. All model files in the "
"model asset bundle are: %s.",
filename, all_model_files),
MediaPipeTasksStatus::kFileNotFoundError);
}
return it->second;
}
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
std::vector<std::string> model_names;
for (const auto& [model_name, _] : model_files_) {
model_names.push_back(model_name);
}
return model_names;
}
} // namespace core
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,92 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
#define MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/core/external_file_handler.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
namespace mediapipe {
namespace tasks {
namespace core {
// The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
// resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the
// resources.
class ModelAssetBundleResources {
public:
// Takes the ownership of the provided ExternalFile proto and creates
// ModelAssetBundleResources from the proto. A non-empty tag
// must be set if the ModelAssetBundleResources will be used through
// ModelResourcesCacheService.
static absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>> Create(
const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// ModelResources is neither copyable nor movable.
ModelAssetBundleResources(const ModelAssetBundleResources&) = delete;
ModelAssetBundleResources& operator=(const ModelAssetBundleResources&) =
delete;
// Returns the model asset bundle resources tag.
std::string GetTag() const { return tag_; }
// Gets the contents of the model file (either tflite model file or model
// bundle file) with the provided name. An error is returned if there is no
// such model file.
absl::StatusOr<absl::string_view> GetModelFile(
const std::string& filename) const;
// Lists all the model file names in the model asset model.
std::vector<std::string> ListModelFiles() const;
private:
// Constructor.
ModelAssetBundleResources(
const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// Extracts the model files (either tflite model file or model bundle file)
// from the external file proto.
absl::Status ExtractModelFilesFromExternalFileProto();
// The model asset bundle resources tag.
const std::string tag_;
// The model asset bundle file.
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file_;
// The ExternalFileHandler for the model asset bundle.
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
// The model files bundled in model asset bundle, as a map with the filename
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
// a pointer to the file contents as value. Each model file can be either
// a TFLite model file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> model_files_;
};
} // namespace core
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_

View File

@ -0,0 +1,229 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include <fcntl.h>
#include <algorithm>
#include <fstream>
#include <iosfwd>
#include <string>
#include <string_view>
#include <vector>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
namespace mediapipe {
namespace tasks {
namespace core {
namespace {
constexpr char kTestModelResourcesTag[] = "test_model_asset_resources";
constexpr char kTestModelBundleResourcesTag[] =
"test_model_asset_bundle_resources";
// Models files in dummy_gesture_recognizer.task:
// gesture_recognizer.task
// dummy_gesture_recognizer.tflite
// dummy_hand_landmarker.task
// dummy_hand_detector.tflite
// dummy_hand_landmarker.tflite
constexpr char kTestModelBundlePath[] =
"mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
constexpr char kInvalidTestModelBundlePath[] =
"mediapipe/tasks/testdata/core/i_do_not_exist.task";
} // namespace
TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_content(LoadBinaryContent(kTestModelBundlePath));
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
.status());
}
TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_name(kTestModelBundlePath);
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
.status());
}
TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
const int model_file_descriptor = open(kTestModelBundlePath, O_RDONLY);
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->mutable_file_descriptor_meta()->set_fd(model_file_descriptor);
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
.status());
}
TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
auto file_content = LoadBinaryContent(kTestModelBundlePath);
auto model_file = std::make_unique<proto::ExternalFile>();
metadata::SetExternalFile(file_content, model_file.get());
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
.status());
}
TEST(ModelAssetBundleResourcesTest, CreateFromInvalidFile) {
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_name(kInvalidTestModelBundlePath);
auto status_or_model_bundle_resources = ModelAssetBundleResources::Create(
kTestModelBundleResourcesTag, std::move(model_file));
EXPECT_EQ(status_or_model_bundle_resources.status().code(),
absl::StatusCode::kNotFound);
EXPECT_THAT(status_or_model_bundle_resources.status().message(),
testing::HasSubstr("Unable to open file"));
EXPECT_THAT(status_or_model_bundle_resources.status().GetPayload(
kMediaPipeTasksPayload),
testing::Optional(absl::Cord(
absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));
}
TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
// Creates top-level model asset bundle resources.
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_name(kTestModelBundlePath);
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
MP_EXPECT_OK(status_or_model_bundle_file.status());
// Creates sub-task model asset bundle resources.
auto hand_landmaker_model_file = std::make_unique<proto::ExternalFile>();
metadata::SetExternalFile(status_or_model_bundle_file.value(),
hand_landmaker_model_file.get());
MP_ASSERT_OK_AND_ASSIGN(
auto hand_landmaker_model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(hand_landmaker_model_file)));
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_detector.tflite")
.status());
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_landmarker.tflite")
.status());
}
TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
// Creates top-level model asset bundle resources.
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_name(kTestModelBundlePath);
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
MP_EXPECT_OK(status_or_model_bundle_file.status());
// Verify tflite model works.
auto hand_detector_model_file = std::make_unique<proto::ExternalFile>();
metadata::SetExternalFile(status_or_model_bundle_file.value(),
hand_detector_model_file.get());
MP_ASSERT_OK_AND_ASSIGN(
auto hand_detector_model_resources,
ModelResources::Create(kTestModelResourcesTag,
std::move(hand_detector_model_file)));
Packet model_packet = hand_detector_model_resources->GetModelPacket();
ASSERT_FALSE(model_packet.IsEmpty());
MP_ASSERT_OK(model_packet.ValidateAsType<ModelResources::ModelPtr>());
EXPECT_TRUE(model_packet.Get<ModelResources::ModelPtr>()->initialized());
}
TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
// Creates top-level model asset bundle resources.
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_name(kTestModelBundlePath);
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(),
testing::HasSubstr(
"No model file with name: not_found.task. All model files in "
"the model asset bundle are: "));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
testing::Optional(absl::Cord(
absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));
}
TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
// Creates top-level model asset bundle resources.
auto model_file = std::make_unique<proto::ExternalFile>();
model_file->set_file_name(kTestModelBundlePath);
MP_ASSERT_OK_AND_ASSIGN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto model_files = model_bundle_resources->ListModelFiles();
std::vector<std::string> expected_model_files = {
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
std::sort(model_files.begin(), model_files.end());
EXPECT_THAT(expected_model_files, testing::ElementsAreArray(model_files));
}
} // namespace core
} // namespace tasks
} // namespace mediapipe

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -88,16 +89,6 @@ constexpr char kCorruptedModelPath[] =
"mediapipe/tasks/testdata/core/" "mediapipe/tasks/testdata/core/"
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite"; "corrupted_mobilenet_v1_0.25_224_1_default_1.tflite";
std::string LoadBinaryContent(const char* filename) {
std::ifstream input_file(filename, std::ios::binary | std::ios::ate);
// Find buffer size from input file, and load the buffer.
size_t buffer_size = input_file.tellg();
std::string buffer(buffer_size, '\0');
input_file.seekg(0, std::ios::beg);
input_file.read(const_cast<char*>(buffer.c_str()), buffer_size);
return buffer;
}
void AssertStatusHasMediaPipeTasksStatusCode( void AssertStatusHasMediaPipeTasksStatusCode(
absl::Status status, MediaPipeTasksStatus mediapipe_tasks_code) { absl::Status status, MediaPipeTasksStatus mediapipe_tasks_code) {
EXPECT_THAT( EXPECT_THAT(

View File

@ -24,6 +24,7 @@ package(
mediapipe_files(srcs = [ mediapipe_files(srcs = [
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite", "corrupted_mobilenet_v1_0.25_224_1_default_1.tflite",
"dummy_gesture_recognizer.task",
"mobilenet_v1_0.25_224_quant.tflite", "mobilenet_v1_0.25_224_quant.tflite",
"test_model_add_op.tflite", "test_model_add_op.tflite",
"test_model_with_custom_op.tflite", "test_model_with_custom_op.tflite",
@ -36,6 +37,7 @@ filegroup(
name = "test_models", name = "test_models",
srcs = [ srcs = [
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite", "corrupted_mobilenet_v1_0.25_224_1_default_1.tflite",
"dummy_gesture_recognizer.task",
"mobilenet_v1_0.25_224_quant.tflite", "mobilenet_v1_0.25_224_quant.tflite",
"test_model_add_op.tflite", "test_model_add_op.tflite",
"test_model_with_custom_op.tflite", "test_model_with_custom_op.tflite",

Binary file not shown.

View File

@ -148,6 +148,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"], urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
) )
http_file(
name = "com_google_mediapipe_dummy_gesture_recognizer_task",
sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665524417056146"],
)
http_file( http_file(
name = "com_google_mediapipe_empty_vocab_for_regex_tokenizer_txt", name = "com_google_mediapipe_empty_vocab_for_regex_tokenizer_txt",
sha256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", sha256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",