Checks if a custom global resource provider is used as the first step of loading the model resources on all platforms.

PiperOrigin-RevId: 493141519
This commit is contained in:
Jiuqiang Tang 2022-12-05 16:18:36 -08:00 committed by Copybara-Service
parent 99d1dd6fbb
commit 1e76d47a71
4 changed files with 21 additions and 15 deletions

View File

@ -117,6 +117,7 @@ cc_library_with_tflite(
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"//mediapipe/util:resource_util_custom",
"//mediapipe/util/tflite:error_reporter", "//mediapipe/util/tflite:error_reporter",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#include "mediapipe/util/resource_util_custom.h"
#include "mediapipe/util/tflite/error_reporter.h" #include "mediapipe/util/tflite/error_reporter.h"
#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -99,21 +100,20 @@ const tflite::Model* ModelResources::GetTfLiteModel() const {
absl::Status ModelResources::BuildModelFromExternalFileProto() { absl::Status ModelResources::BuildModelFromExternalFileProto() {
if (model_file_->has_file_name()) { if (model_file_->has_file_name()) {
#ifdef __EMSCRIPTEN__ if (HasCustomGlobalResourceProvider()) {
// In browsers, the model file may require a custom ResourceProviderFn to // If the model contents are provided via a custom ResourceProviderFn, the
// provide the model content. The open() method may not work in this case. // open() method may not work. Thus, loads the model content from the
// Thus, loading the model content from the model file path in advance with // model file path in advance with the help of GetResourceContents.
// the help of GetResourceContents. MP_RETURN_IF_ERROR(GetResourceContents(
MP_RETURN_IF_ERROR(mediapipe::GetResourceContents(
model_file_->file_name(), model_file_->mutable_file_content())); model_file_->file_name(), model_file_->mutable_file_content()));
model_file_->clear_file_name(); model_file_->clear_file_name();
#else } else {
// If the model file name is a relative path, searches the file in a // If the model file name is a relative path, searches the file in a
// platform-specific location and returns the absolute path on success. // platform-specific location and returns the absolute path on success.
ASSIGN_OR_RETURN(std::string path_to_resource, ASSIGN_OR_RETURN(std::string path_to_resource,
mediapipe::PathToResourceAsFile(model_file_->file_name())); PathToResourceAsFile(model_file_->file_name()));
model_file_->set_file_name(path_to_resource); model_file_->set_file_name(path_to_resource);
#endif // __EMSCRIPTEN__ }
} }
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
model_file_handler_, model_file_handler_,

View File

@ -37,6 +37,8 @@ absl::Status GetResourceContents(const std::string& path, std::string* output,
return internal::DefaultGetResourceContents(path, output, read_as_binary); return internal::DefaultGetResourceContents(path, output, read_as_binary);
} }
bool HasCustomGlobalResourceProvider() { return resource_provider_ != nullptr; }
void SetCustomGlobalResourceProvider(ResourceProviderFn fn) { void SetCustomGlobalResourceProvider(ResourceProviderFn fn) {
resource_provider_ = std::move(fn); resource_provider_ = std::move(fn);
} }

View File

@ -10,6 +10,9 @@ namespace mediapipe {
typedef std::function<absl::Status(const std::string&, std::string*)> typedef std::function<absl::Status(const std::string&, std::string*)>
ResourceProviderFn; ResourceProviderFn;
// Returns true if files are provided via a custom resource provider.
bool HasCustomGlobalResourceProvider();
// Overrides the behavior of GetResourceContents. // Overrides the behavior of GetResourceContents.
void SetCustomGlobalResourceProvider(ResourceProviderFn fn); void SetCustomGlobalResourceProvider(ResourceProviderFn fn);