Internal change
PiperOrigin-RevId: 481255129
This commit is contained in:
parent
ca28a19822
commit
eb52b72707
|
@ -138,6 +138,7 @@ cc_test_with_tflite(
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":utils",
|
||||||
"//mediapipe/framework/api2:packet",
|
"//mediapipe/framework/api2:packet",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
|
@ -314,3 +315,41 @@ cc_library(
|
||||||
"@flatbuffers//:runtime_cc",
|
"@flatbuffers//:runtime_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "model_asset_bundle_resources",
|
||||||
|
srcs = ["model_asset_bundle_resources.cc"],
|
||||||
|
hdrs = ["model_asset_bundle_resources.h"],
|
||||||
|
deps = [
|
||||||
|
":external_file_handler",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||||
|
"//mediapipe/util:resource_util",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "model_asset_bundle_resources_test",
|
||||||
|
srcs = ["model_asset_bundle_resources_test.cc"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/core:test_models",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":model_asset_bundle_resources",
|
||||||
|
":model_resources",
|
||||||
|
":utils",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
107
mediapipe/tasks/cc/core/model_asset_bundle_resources.cc
Normal file
107
mediapipe/tasks/cc/core/model_asset_bundle_resources.cc
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
|
#include "mediapipe/util/resource_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ::absl::StatusCode;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
ModelAssetBundleResources::ModelAssetBundleResources(
|
||||||
|
const std::string& tag,
|
||||||
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file)
|
||||||
|
: tag_(tag), model_asset_bundle_file_(std::move(model_asset_bundle_file)) {}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>>
|
||||||
|
ModelAssetBundleResources::Create(
|
||||||
|
const std::string& tag,
|
||||||
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file) {
|
||||||
|
if (model_asset_bundle_file == nullptr) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
"The model asset bundle file proto cannot be nullptr.",
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
auto model_bundle_resources = absl::WrapUnique(
|
||||||
|
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
|
||||||
|
return model_bundle_resources;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status
|
||||||
|
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
||||||
|
if (model_asset_bundle_file_->has_file_name()) {
|
||||||
|
// If the model asset bundle file name is a relative path, searches the file
|
||||||
|
// in a platform-specific location and returns the absolute path on success.
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
std::string path_to_resource,
|
||||||
|
mediapipe::PathToResourceAsFile(model_asset_bundle_file_->file_name()));
|
||||||
|
model_asset_bundle_file_->set_file_name(path_to_resource);
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(model_asset_bundle_file_handler_,
|
||||||
|
ExternalFileHandler::CreateFromExternalFile(
|
||||||
|
model_asset_bundle_file_.get()));
|
||||||
|
const char* buffer_data =
|
||||||
|
model_asset_bundle_file_handler_->GetFileContent().data();
|
||||||
|
size_t buffer_size =
|
||||||
|
model_asset_bundle_file_handler_->GetFileContent().size();
|
||||||
|
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
|
||||||
|
&model_files_);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
|
||||||
|
const std::string& filename) const {
|
||||||
|
auto it = model_files_.find(filename);
|
||||||
|
if (it == model_files_.end()) {
|
||||||
|
auto model_files = ListModelFiles();
|
||||||
|
std::string all_model_files =
|
||||||
|
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
|
||||||
|
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kNotFound,
|
||||||
|
absl::StrFormat("No model file with name: %s. All model files in the "
|
||||||
|
"model asset bundle are: %s.",
|
||||||
|
filename, all_model_files),
|
||||||
|
MediaPipeTasksStatus::kFileNotFoundError);
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
|
||||||
|
std::vector<std::string> model_names;
|
||||||
|
for (const auto& [model_name, _] : model_files_) {
|
||||||
|
model_names.push_back(model_name);
|
||||||
|
}
|
||||||
|
return model_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
92
mediapipe/tasks/cc/core/model_asset_bundle_resources.h
Normal file
92
mediapipe/tasks/cc/core/model_asset_bundle_resources.h
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/external_file_handler.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
// The mediapipe task model asset bundle resources class.
|
||||||
|
// A ModelAssetBundleResources object, created from an external file proto,
|
||||||
|
// contains model asset bundle related resources and the method to extract the
|
||||||
|
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
|
||||||
|
// resources are owned by the ModelAssetBundleResources object
|
||||||
|
// callers must keep ModelAssetBundleResources alive while using any of the
|
||||||
|
// resources.
|
||||||
|
class ModelAssetBundleResources {
|
||||||
|
public:
|
||||||
|
// Takes the ownership of the provided ExternalFile proto and creates
|
||||||
|
// ModelAssetBundleResources from the proto. A non-empty tag
|
||||||
|
// must be set if the ModelAssetBundleResources will be used through
|
||||||
|
// ModelResourcesCacheService.
|
||||||
|
static absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>> Create(
|
||||||
|
const std::string& tag,
|
||||||
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
||||||
|
|
||||||
|
// ModelResources is neither copyable nor movable.
|
||||||
|
ModelAssetBundleResources(const ModelAssetBundleResources&) = delete;
|
||||||
|
ModelAssetBundleResources& operator=(const ModelAssetBundleResources&) =
|
||||||
|
delete;
|
||||||
|
|
||||||
|
// Returns the model asset bundle resources tag.
|
||||||
|
std::string GetTag() const { return tag_; }
|
||||||
|
|
||||||
|
// Gets the contents of the model file (either tflite model file or model
|
||||||
|
// bundle file) with the provided name. An error is returned if there is no
|
||||||
|
// such model file.
|
||||||
|
absl::StatusOr<absl::string_view> GetModelFile(
|
||||||
|
const std::string& filename) const;
|
||||||
|
|
||||||
|
// Lists all the model file names in the model asset model.
|
||||||
|
std::vector<std::string> ListModelFiles() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Constructor.
|
||||||
|
ModelAssetBundleResources(
|
||||||
|
const std::string& tag,
|
||||||
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
||||||
|
|
||||||
|
// Extracts the model files (either tflite model file or model bundle file)
|
||||||
|
// from the external file proto.
|
||||||
|
absl::Status ExtractModelFilesFromExternalFileProto();
|
||||||
|
|
||||||
|
// The model asset bundle resources tag.
|
||||||
|
const std::string tag_;
|
||||||
|
|
||||||
|
// The model asset bundle file.
|
||||||
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file_;
|
||||||
|
|
||||||
|
// The ExternalFileHandler for the model asset bundle.
|
||||||
|
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
|
||||||
|
|
||||||
|
// The model files bundled in model asset bundle, as a map with the filename
|
||||||
|
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
|
||||||
|
// a pointer to the file contents as value. Each model file can be either
|
||||||
|
// a TFLite model file or a model bundle file for sub-task.
|
||||||
|
absl::flat_hash_map<std::string, absl::string_view> model_files_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
|
229
mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc
Normal file
229
mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc
Normal file
|
@ -0,0 +1,229 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
|
|
||||||
|
#include <fcntl.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iosfwd>
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace core {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kTestModelResourcesTag[] = "test_model_asset_resources";
|
||||||
|
|
||||||
|
constexpr char kTestModelBundleResourcesTag[] =
|
||||||
|
"test_model_asset_bundle_resources";
|
||||||
|
|
||||||
|
// Models files in dummy_gesture_recognizer.task:
|
||||||
|
// gesture_recognizer.task
|
||||||
|
// dummy_gesture_recognizer.tflite
|
||||||
|
// dummy_hand_landmarker.task
|
||||||
|
// dummy_hand_detector.tflite
|
||||||
|
// dummy_hand_landmarker.tflite
|
||||||
|
constexpr char kTestModelBundlePath[] =
|
||||||
|
"mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
|
||||||
|
|
||||||
|
constexpr char kInvalidTestModelBundlePath[] =
|
||||||
|
"mediapipe/tasks/testdata/core/i_do_not_exist.task";
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_content(LoadBinaryContent(kTestModelBundlePath));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||||
|
.status());
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||||
|
.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_name(kTestModelBundlePath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||||
|
.status());
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||||
|
.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
|
||||||
|
const int model_file_descriptor = open(kTestModelBundlePath, O_RDONLY);
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->mutable_file_descriptor_meta()->set_fd(model_file_descriptor);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||||
|
.status());
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||||
|
.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
|
||||||
|
auto file_content = LoadBinaryContent(kTestModelBundlePath);
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
metadata::SetExternalFile(file_content, model_file.get());
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||||
|
.status());
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||||
|
.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, CreateFromInvalidFile) {
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_name(kInvalidTestModelBundlePath);
|
||||||
|
auto status_or_model_bundle_resources = ModelAssetBundleResources::Create(
|
||||||
|
kTestModelBundleResourcesTag, std::move(model_file));
|
||||||
|
|
||||||
|
EXPECT_EQ(status_or_model_bundle_resources.status().code(),
|
||||||
|
absl::StatusCode::kNotFound);
|
||||||
|
EXPECT_THAT(status_or_model_bundle_resources.status().message(),
|
||||||
|
testing::HasSubstr("Unable to open file"));
|
||||||
|
EXPECT_THAT(status_or_model_bundle_resources.status().GetPayload(
|
||||||
|
kMediaPipeTasksPayload),
|
||||||
|
testing::Optional(absl::Cord(
|
||||||
|
absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
||||||
|
// Creates top-level model asset bundle resources.
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_name(kTestModelBundlePath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
auto status_or_model_bundle_file =
|
||||||
|
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
|
||||||
|
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||||
|
|
||||||
|
// Creates sub-task model asset bundle resources.
|
||||||
|
auto hand_landmaker_model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
metadata::SetExternalFile(status_or_model_bundle_file.value(),
|
||||||
|
hand_landmaker_model_file.get());
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto hand_landmaker_model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(hand_landmaker_model_file)));
|
||||||
|
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||||
|
->GetModelFile("dummy_hand_detector.tflite")
|
||||||
|
.status());
|
||||||
|
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||||
|
->GetModelFile("dummy_hand_landmarker.tflite")
|
||||||
|
.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
|
||||||
|
// Creates top-level model asset bundle resources.
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_name(kTestModelBundlePath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
auto status_or_model_bundle_file =
|
||||||
|
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
|
||||||
|
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||||
|
|
||||||
|
// Verify tflite model works.
|
||||||
|
auto hand_detector_model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
metadata::SetExternalFile(status_or_model_bundle_file.value(),
|
||||||
|
hand_detector_model_file.get());
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto hand_detector_model_resources,
|
||||||
|
ModelResources::Create(kTestModelResourcesTag,
|
||||||
|
std::move(hand_detector_model_file)));
|
||||||
|
Packet model_packet = hand_detector_model_resources->GetModelPacket();
|
||||||
|
ASSERT_FALSE(model_packet.IsEmpty());
|
||||||
|
MP_ASSERT_OK(model_packet.ValidateAsType<ModelResources::ModelPtr>());
|
||||||
|
EXPECT_TRUE(model_packet.Get<ModelResources::ModelPtr>()->initialized());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
|
||||||
|
// Creates top-level model asset bundle resources.
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_name(kTestModelBundlePath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
testing::HasSubstr(
|
||||||
|
"No model file with name: not_found.task. All model files in "
|
||||||
|
"the model asset bundle are: "));
|
||||||
|
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
||||||
|
testing::Optional(absl::Cord(
|
||||||
|
absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
|
||||||
|
// Creates top-level model asset bundle resources.
|
||||||
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
model_file->set_file_name(kTestModelBundlePath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
|
std::move(model_file)));
|
||||||
|
auto model_files = model_bundle_resources->ListModelFiles();
|
||||||
|
std::vector<std::string> expected_model_files = {
|
||||||
|
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
|
||||||
|
std::sort(model_files.begin(), model_files.end());
|
||||||
|
EXPECT_THAT(expected_model_files, testing::ElementsAreArray(model_files));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -88,16 +89,6 @@ constexpr char kCorruptedModelPath[] =
|
||||||
"mediapipe/tasks/testdata/core/"
|
"mediapipe/tasks/testdata/core/"
|
||||||
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite";
|
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite";
|
||||||
|
|
||||||
std::string LoadBinaryContent(const char* filename) {
|
|
||||||
std::ifstream input_file(filename, std::ios::binary | std::ios::ate);
|
|
||||||
// Find buffer size from input file, and load the buffer.
|
|
||||||
size_t buffer_size = input_file.tellg();
|
|
||||||
std::string buffer(buffer_size, '\0');
|
|
||||||
input_file.seekg(0, std::ios::beg);
|
|
||||||
input_file.read(const_cast<char*>(buffer.c_str()), buffer_size);
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AssertStatusHasMediaPipeTasksStatusCode(
|
void AssertStatusHasMediaPipeTasksStatusCode(
|
||||||
absl::Status status, MediaPipeTasksStatus mediapipe_tasks_code) {
|
absl::Status status, MediaPipeTasksStatus mediapipe_tasks_code) {
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
|
2
mediapipe/tasks/testdata/core/BUILD
vendored
2
mediapipe/tasks/testdata/core/BUILD
vendored
|
@ -24,6 +24,7 @@ package(
|
||||||
|
|
||||||
mediapipe_files(srcs = [
|
mediapipe_files(srcs = [
|
||||||
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite",
|
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||||
|
"dummy_gesture_recognizer.task",
|
||||||
"mobilenet_v1_0.25_224_quant.tflite",
|
"mobilenet_v1_0.25_224_quant.tflite",
|
||||||
"test_model_add_op.tflite",
|
"test_model_add_op.tflite",
|
||||||
"test_model_with_custom_op.tflite",
|
"test_model_with_custom_op.tflite",
|
||||||
|
@ -36,6 +37,7 @@ filegroup(
|
||||||
name = "test_models",
|
name = "test_models",
|
||||||
srcs = [
|
srcs = [
|
||||||
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite",
|
"corrupted_mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||||
|
"dummy_gesture_recognizer.task",
|
||||||
"mobilenet_v1_0.25_224_quant.tflite",
|
"mobilenet_v1_0.25_224_quant.tflite",
|
||||||
"test_model_add_op.tflite",
|
"test_model_add_op.tflite",
|
||||||
"test_model_with_custom_op.tflite",
|
"test_model_with_custom_op.tflite",
|
||||||
|
|
BIN
mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task
vendored
Normal file
BIN
mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task
vendored
Normal file
Binary file not shown.
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -148,6 +148,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_dummy_gesture_recognizer_task",
|
||||||
|
sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665524417056146"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_empty_vocab_for_regex_tokenizer_txt",
|
name = "com_google_mediapipe_empty_vocab_for_regex_tokenizer_txt",
|
||||||
sha256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
sha256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user