Update model_task_graph to support multiple local model resources.

PiperOrigin-RevId: 482917453
This commit is contained in:
MediaPipe Team 2022-10-21 16:41:07 -07:00 committed by Copybara-Service
parent 4a6c23a76a
commit ea1d85d811
2 changed files with 32 additions and 15 deletions

View File

@ -156,21 +156,24 @@ absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
} }
absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources( absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) { SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix) {
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
if (!model_resources_cache_service.IsAvailable()) { if (!model_resources_cache_service.IsAvailable()) {
ASSIGN_OR_RETURN(local_model_resources_, ASSIGN_OR_RETURN(auto local_model_resource,
ModelResources::Create("", std::move(external_file))); ModelResources::Create("", std::move(external_file)));
LOG(WARNING) LOG(WARNING)
<< "A local ModelResources object is created. Please consider using " << "A local ModelResources object is created. Please consider using "
"ModelResourcesCacheService to cache the created ModelResources " "ModelResourcesCacheService to cache the created ModelResources "
"object in the CalculatorGraph."; "object in the CalculatorGraph.";
return local_model_resources_.get(); local_model_resources_.push_back(std::move(local_model_resource));
return local_model_resources_.back().get();
} }
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto op_resolver_packet, auto op_resolver_packet,
model_resources_cache_service.GetObject().GetGraphOpResolverPacket()); model_resources_cache_service.GetObject().GetGraphOpResolverPacket());
const std::string tag = CreateModelResourcesTag(sc->OriginalNode()); const std::string tag =
absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix);
ASSIGN_OR_RETURN(auto model_resources, ASSIGN_OR_RETURN(auto model_resources,
ModelResources::Create(tag, std::move(external_file), ModelResources::Create(tag, std::move(external_file),
op_resolver_packet)); op_resolver_packet));
@ -182,7 +185,8 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
absl::StatusOr<const ModelAssetBundleResources*> absl::StatusOr<const ModelAssetBundleResources*>
ModelTaskGraph::CreateModelAssetBundleResources( ModelTaskGraph::CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) { SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix) {
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
bool has_file_pointer_meta = external_file->has_file_pointer_meta(); bool has_file_pointer_meta = external_file->has_file_pointer_meta();
// if external file is set by file pointer, no need to add the model asset // if external file is set by file pointer, no need to add the model asset
@ -190,7 +194,7 @@ ModelTaskGraph::CreateModelAssetBundleResources(
// not owned by this model asset bundle resources. // not owned by this model asset bundle resources.
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) { if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
local_model_asset_bundle_resources_, auto local_model_asset_bundle_resource,
ModelAssetBundleResources::Create("", std::move(external_file))); ModelAssetBundleResources::Create("", std::move(external_file)));
if (!has_file_pointer_meta) { if (!has_file_pointer_meta) {
LOG(WARNING) LOG(WARNING)
@ -198,10 +202,12 @@ ModelTaskGraph::CreateModelAssetBundleResources(
"ModelResourcesCacheService to cache the created ModelResources " "ModelResourcesCacheService to cache the created ModelResources "
"object in the CalculatorGraph."; "object in the CalculatorGraph.";
} }
return local_model_asset_bundle_resources_.get(); local_model_asset_bundle_resources_.push_back(
std::move(local_model_asset_bundle_resource));
return local_model_asset_bundle_resources_.back().get();
} }
const std::string tag = const std::string tag = absl::StrCat(
CreateModelAssetBundleResourcesTag(sc->OriginalNode()); CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix);
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(tag, std::move(external_file))); ModelAssetBundleResources::Create(tag, std::move(external_file)));

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
@ -75,9 +76,14 @@ class ModelTaskGraph : public Subgraph {
// construction stage. Note that the external file contents will be moved // construction stage. Note that the external file contents will be moved
// into the model resources object on creation. The returned model resources // into the model resources object on creation. The returned model resources
// pointer will provide graph authors with the access to the metadata // pointer will provide graph authors with the access to the metadata
// extractor and the tflite model. // extractor and the tflite model. When the model resources graph service is
// available, a tag is generated internally asscoiated with the created model
// resource. If more than one model resources are created in a graph, the
// model resources graph service add the tag_suffix to support multiple
// resources.
absl::StatusOr<const ModelResources*> CreateModelResources( absl::StatusOr<const ModelResources*> CreateModelResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file); SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix = "");
// If the model resources graph service is available, creates a model asset // If the model resources graph service is available, creates a model asset
// bundle resources object from the subgraph context, and caches the created // bundle resources object from the subgraph context, and caches the created
@ -103,10 +109,15 @@ class ModelTaskGraph : public Subgraph {
// that can only be used in the graph construction stage. Note that the // that can only be used in the graph construction stage. Note that the
// external file contents will be moved into the model asset bundle resources // external file contents will be moved into the model asset bundle resources
// object on creation. The returned model asset bundle resources pointer will // object on creation. The returned model asset bundle resources pointer will
// provide graph authors with the access to extracted model files. // provide graph authors with the access to extracted model files. When the
// model resources graph service is available, a tag is generated internally
// asscoiated with the created model asset bundle resource. If more than one
// model asset bundle resources are created in a graph, the model resources
// graph service add the tag_suffix to support multiple resources.
absl::StatusOr<const ModelAssetBundleResources*> absl::StatusOr<const ModelAssetBundleResources*>
CreateModelAssetBundleResources( CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file); SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix = "");
// Inserts a mediapipe task inference subgraph into the provided // Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the // GraphBuilder. The returned node provides the following interfaces to the
@ -124,9 +135,9 @@ class ModelTaskGraph : public Subgraph {
api2::builder::Graph& graph) const; api2::builder::Graph& graph) const;
private: private:
std::unique_ptr<ModelResources> local_model_resources_; std::vector<std::unique_ptr<ModelResources>> local_model_resources_;
std::unique_ptr<ModelAssetBundleResources> std::vector<std::unique_ptr<ModelAssetBundleResources>>
local_model_asset_bundle_resources_; local_model_asset_bundle_resources_;
}; };