mediapipe/mediapipe/tasks/cc/core/model_resources.cc
Bekzhan Bekbolatuly 54d208aa5c Internal change
PiperOrigin-RevId: 524345939
2023-04-14 11:44:36 -07:00

166 lines
6.9 KiB
C++

/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/core/model_resources.h"
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/external_file_handler.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/util/resource_util.h"
#include "mediapipe/util/resource_util_custom.h"
#include "mediapipe/util/tflite/error_reporter.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/tools/verifier.h"
namespace mediapipe {
namespace tasks {
namespace core {
using ::absl::StatusCode;
using ::mediapipe::api2::MakePacket;
using ::mediapipe::api2::Packet;
using ::mediapipe::api2::PacketAdopting;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
// Verifier hook handed to the TFLite model builder. Forwards the buffer, its
// length and the error reporter unchanged to tflite::Verify and returns that
// call's result.
bool ModelResources::Verifier::Verify(const char* data, int length,
                                      tflite::ErrorReporter* reporter) {
  const bool verified = tflite::Verify(data, length, reporter);
  return verified;
}
// Constructs a ModelResources without loading anything: the constructor body
// is empty and only initializes members.
//
// Args:
//   tag: copied into `tag_`.
//   model_file: ownership is transferred into `model_file_`.
//   op_resolver_packet: by-value sink parameter stored in
//     `op_resolver_packet_`.
ModelResources::ModelResources(const std::string& tag,
                               std::unique_ptr<proto::ExternalFile> model_file,
                               Packet<tflite::OpResolver> op_resolver_packet)
    : tag_(tag),
      model_file_(std::move(model_file)),
      // The parameter is already a private copy owned by this call, so hand
      // it over with std::move instead of copying it a second time.
      op_resolver_packet_(std::move(op_resolver_packet)) {}
/* static */
// Convenience overload taking the op resolver as a unique_ptr. The resolver is
// adopted into a Packet<tflite::OpResolver> and the call is delegated to the
// packet-based Create overload, whose result is returned as-is.
absl::StatusOr<std::unique_ptr<ModelResources>> ModelResources::Create(
    const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
    std::unique_ptr<tflite::OpResolver> op_resolver) {
  auto resolver_packet =
      PacketAdopting<tflite::OpResolver>(std::move(op_resolver));
  return Create(tag, std::move(model_file), std::move(resolver_packet));
}
/* static */
// Creates a ModelResources from an ExternalFile proto and an op resolver
// packet, then builds the model via BuildModelFromExternalFileProto().
//
// Args:
//   tag: forwarded to the ModelResources constructor.
//   model_file: ownership is transferred to the new ModelResources.
//   op_resolver_packet: packet holding the op resolver; must not be empty.
//
// Returns:
//   The new ModelResources on success. An kInvalidArgument status (payload
//   MediaPipeTasksStatus::kInvalidArgumentError) if `model_file` is nullptr or
//   `op_resolver_packet` is empty. Otherwise, any error returned by
//   BuildModelFromExternalFileProto().
absl::StatusOr<std::unique_ptr<ModelResources>> ModelResources::Create(
    const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
    Packet<tflite::OpResolver> op_resolver_packet) {
  if (model_file == nullptr) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "The model file proto cannot be nullptr.",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  if (op_resolver_packet.IsEmpty()) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "The op resolver packet must be non-empty.",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  // `op_resolver_packet` is not used after this point, so it is moved into the
  // constructor rather than copied.
  auto model_resources = absl::WrapUnique(new ModelResources(
      tag, std::move(model_file), std::move(op_resolver_packet)));
  MP_RETURN_IF_ERROR(model_resources->BuildModelFromExternalFileProto());
  return model_resources;
}
// Returns the tflite::Model view of the loaded model. When TFLITE_IN_GMSCORE
// is set, the model is obtained by calling tflite::GetModel on the raw file
// content held by `model_file_handler_`; otherwise it comes from the
// FlatBufferModel stored in `model_packet_`.
const tflite::Model* ModelResources::GetTfLiteModel() const {
#if TFLITE_IN_GMSCORE
  const auto* raw_model_bytes = model_file_handler_->GetFileContent().data();
  return tflite::GetModel(raw_model_bytes);
#else
  const auto& flatbuffer_model = model_packet_.Get();
  return flatbuffer_model->GetModel();
#endif
}
// Loads the model referenced by `model_file_`, verifies and builds it, and
// extracts its metadata.
//
// Steps:
//   1. If the proto carries a file name, either preload the content through
//      GetResourceContents (custom global resource provider installed) or
//      rewrite the file name with the result of PathToResourceAsFile.
//   2. Create `model_file_handler_` from the proto.
//   3. Verify and build a tflite::FlatBufferModel from the handler's buffer and
//      store it in `model_packet_`.
//   4. Create a ModelMetadataExtractor over the same buffer and store it in
//      `metadata_extractor_packet_`.
//
// Returns absl::OkStatus() on success; otherwise the first error encountered.
// A failed model build is reported as kInvalidArgument with payload
// kInvalidFlatBufferError when the reporter message identifies an invalid
// flatbuffer, and as kUnknown otherwise.
absl::Status ModelResources::BuildModelFromExternalFileProto() {
  if (model_file_->has_file_name()) {
    if (!HasCustomGlobalResourceProvider()) {
      // For a relative model file name, the file is searched in a
      // platform-specific location and the absolute path is returned on
      // success.
      ASSIGN_OR_RETURN(std::string resolved_path,
                       PathToResourceAsFile(model_file_->file_name()));
      model_file_->set_file_name(resolved_path);
    } else {
      // With a custom ResourceProviderFn supplying the model contents, open()
      // may not work. The content is therefore read up front from the model
      // file path via GetResourceContents, and the file name is cleared.
      MP_RETURN_IF_ERROR(GetResourceContents(
          model_file_->file_name(), model_file_->mutable_file_content()));
      model_file_->clear_file_name();
    }
  }

  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(model_file_.get()));
  const auto& file_content = model_file_handler_->GetFileContent();
  const char* model_bytes = file_content.data();
  size_t model_byte_count = file_content.size();

  // Checks that the buffer is a valid flatbuffer model that only uses
  // operators supported by the OpResolver given to the ModelResources
  // constructor, then builds the model from that buffer.
  auto flatbuffer_model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
      model_bytes, model_byte_count, &verifier_, &error_reporter_);
  if (flatbuffer_model == nullptr) {
    static constexpr char kInvalidFlatbufferMessage[] =
        "The model is not a valid Flatbuffer";
    const auto& reported_error = error_reporter_.message();
    // String matching stands in for a proper switch-case until the TFLite
    // model builder returns a `MediaPipeTasksStatus` code for this kind of
    // error.
    if (absl::StrContains(reported_error, kInvalidFlatbufferMessage)) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument, reported_error,
          MediaPipeTasksStatus::kInvalidFlatBufferError);
    }
    if (absl::StrContains(reported_error, "Error loading model from buffer")) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument, kInvalidFlatbufferMessage,
          MediaPipeTasksStatus::kInvalidFlatBufferError);
    }
    return CreateStatusWithPayload(
        StatusCode::kUnknown,
        absl::StrCat(
            "Could not build model from the provided pre-loaded flatbuffer: ",
            reported_error));
  }

  // Ownership of the built model moves into the packet, which deletes it
  // through the supplied deleter.
  model_packet_ = MakePacket<ModelPtr>(
      flatbuffer_model.release(),
      [](tflite::FlatBufferModel* owned_model) { delete owned_model; });

  ASSIGN_OR_RETURN(auto extractor,
                   metadata::ModelMetadataExtractor::CreateFromModelBuffer(
                       model_bytes, model_byte_count));
  metadata_extractor_packet_ =
      PacketAdopting<metadata::ModelMetadataExtractor>(std::move(extractor));
  return absl::OkStatus();
}
} // namespace core
} // namespace tasks
} // namespace mediapipe