Update model_task_graph to support multiple local model resources.
PiperOrigin-RevId: 482917453
This commit is contained in:
parent
4a6c23a76a
commit
ea1d85d811
|
@ -156,21 +156,24 @@ absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
|
absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
|
||||||
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
|
||||||
|
const std::string tag_suffix) {
|
||||||
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
|
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
|
||||||
if (!model_resources_cache_service.IsAvailable()) {
|
if (!model_resources_cache_service.IsAvailable()) {
|
||||||
ASSIGN_OR_RETURN(local_model_resources_,
|
ASSIGN_OR_RETURN(auto local_model_resource,
|
||||||
ModelResources::Create("", std::move(external_file)));
|
ModelResources::Create("", std::move(external_file)));
|
||||||
LOG(WARNING)
|
LOG(WARNING)
|
||||||
<< "A local ModelResources object is created. Please consider using "
|
<< "A local ModelResources object is created. Please consider using "
|
||||||
"ModelResourcesCacheService to cache the created ModelResources "
|
"ModelResourcesCacheService to cache the created ModelResources "
|
||||||
"object in the CalculatorGraph.";
|
"object in the CalculatorGraph.";
|
||||||
return local_model_resources_.get();
|
local_model_resources_.push_back(std::move(local_model_resource));
|
||||||
|
return local_model_resources_.back().get();
|
||||||
}
|
}
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto op_resolver_packet,
|
auto op_resolver_packet,
|
||||||
model_resources_cache_service.GetObject().GetGraphOpResolverPacket());
|
model_resources_cache_service.GetObject().GetGraphOpResolverPacket());
|
||||||
const std::string tag = CreateModelResourcesTag(sc->OriginalNode());
|
const std::string tag =
|
||||||
|
absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix);
|
||||||
ASSIGN_OR_RETURN(auto model_resources,
|
ASSIGN_OR_RETURN(auto model_resources,
|
||||||
ModelResources::Create(tag, std::move(external_file),
|
ModelResources::Create(tag, std::move(external_file),
|
||||||
op_resolver_packet));
|
op_resolver_packet));
|
||||||
|
@ -182,7 +185,8 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
|
||||||
|
|
||||||
absl::StatusOr<const ModelAssetBundleResources*>
|
absl::StatusOr<const ModelAssetBundleResources*>
|
||||||
ModelTaskGraph::CreateModelAssetBundleResources(
|
ModelTaskGraph::CreateModelAssetBundleResources(
|
||||||
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
|
||||||
|
const std::string tag_suffix) {
|
||||||
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
|
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
|
||||||
bool has_file_pointer_meta = external_file->has_file_pointer_meta();
|
bool has_file_pointer_meta = external_file->has_file_pointer_meta();
|
||||||
// if external file is set by file pointer, no need to add the model asset
|
// if external file is set by file pointer, no need to add the model asset
|
||||||
|
@ -190,7 +194,7 @@ ModelTaskGraph::CreateModelAssetBundleResources(
|
||||||
// not owned by this model asset bundle resources.
|
// not owned by this model asset bundle resources.
|
||||||
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
|
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
local_model_asset_bundle_resources_,
|
auto local_model_asset_bundle_resource,
|
||||||
ModelAssetBundleResources::Create("", std::move(external_file)));
|
ModelAssetBundleResources::Create("", std::move(external_file)));
|
||||||
if (!has_file_pointer_meta) {
|
if (!has_file_pointer_meta) {
|
||||||
LOG(WARNING)
|
LOG(WARNING)
|
||||||
|
@ -198,10 +202,12 @@ ModelTaskGraph::CreateModelAssetBundleResources(
|
||||||
"ModelResourcesCacheService to cache the created ModelResources "
|
"ModelResourcesCacheService to cache the created ModelResources "
|
||||||
"object in the CalculatorGraph.";
|
"object in the CalculatorGraph.";
|
||||||
}
|
}
|
||||||
return local_model_asset_bundle_resources_.get();
|
local_model_asset_bundle_resources_.push_back(
|
||||||
|
std::move(local_model_asset_bundle_resource));
|
||||||
|
return local_model_asset_bundle_resources_.back().get();
|
||||||
}
|
}
|
||||||
const std::string tag =
|
const std::string tag = absl::StrCat(
|
||||||
CreateModelAssetBundleResourcesTag(sc->OriginalNode());
|
CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix);
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto model_bundle_resources,
|
auto model_bundle_resources,
|
||||||
ModelAssetBundleResources::Create(tag, std::move(external_file)));
|
ModelAssetBundleResources::Create(tag, std::move(external_file)));
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
|
@ -75,9 +76,14 @@ class ModelTaskGraph : public Subgraph {
|
||||||
// construction stage. Note that the external file contents will be moved
|
// construction stage. Note that the external file contents will be moved
|
||||||
// into the model resources object on creation. The returned model resources
|
// into the model resources object on creation. The returned model resources
|
||||||
// pointer will provide graph authors with the access to the metadata
|
// pointer will provide graph authors with the access to the metadata
|
||||||
// extractor and the tflite model.
|
// extractor and the tflite model. When the model resources graph service is
|
||||||
|
// available, a tag is generated internally asscoiated with the created model
|
||||||
|
// resource. If more than one model resources are created in a graph, the
|
||||||
|
// model resources graph service add the tag_suffix to support multiple
|
||||||
|
// resources.
|
||||||
absl::StatusOr<const ModelResources*> CreateModelResources(
|
absl::StatusOr<const ModelResources*> CreateModelResources(
|
||||||
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
|
||||||
|
const std::string tag_suffix = "");
|
||||||
|
|
||||||
// If the model resources graph service is available, creates a model asset
|
// If the model resources graph service is available, creates a model asset
|
||||||
// bundle resources object from the subgraph context, and caches the created
|
// bundle resources object from the subgraph context, and caches the created
|
||||||
|
@ -103,10 +109,15 @@ class ModelTaskGraph : public Subgraph {
|
||||||
// that can only be used in the graph construction stage. Note that the
|
// that can only be used in the graph construction stage. Note that the
|
||||||
// external file contents will be moved into the model asset bundle resources
|
// external file contents will be moved into the model asset bundle resources
|
||||||
// object on creation. The returned model asset bundle resources pointer will
|
// object on creation. The returned model asset bundle resources pointer will
|
||||||
// provide graph authors with the access to extracted model files.
|
// provide graph authors with the access to extracted model files. When the
|
||||||
|
// model resources graph service is available, a tag is generated internally
|
||||||
|
// asscoiated with the created model asset bundle resource. If more than one
|
||||||
|
// model asset bundle resources are created in a graph, the model resources
|
||||||
|
// graph service add the tag_suffix to support multiple resources.
|
||||||
absl::StatusOr<const ModelAssetBundleResources*>
|
absl::StatusOr<const ModelAssetBundleResources*>
|
||||||
CreateModelAssetBundleResources(
|
CreateModelAssetBundleResources(
|
||||||
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
|
||||||
|
const std::string tag_suffix = "");
|
||||||
|
|
||||||
// Inserts a mediapipe task inference subgraph into the provided
|
// Inserts a mediapipe task inference subgraph into the provided
|
||||||
// GraphBuilder. The returned node provides the following interfaces to the
|
// GraphBuilder. The returned node provides the following interfaces to the
|
||||||
|
@ -124,9 +135,9 @@ class ModelTaskGraph : public Subgraph {
|
||||||
api2::builder::Graph& graph) const;
|
api2::builder::Graph& graph) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<ModelResources> local_model_resources_;
|
std::vector<std::unique_ptr<ModelResources>> local_model_resources_;
|
||||||
|
|
||||||
std::unique_ptr<ModelAssetBundleResources>
|
std::vector<std::unique_ptr<ModelAssetBundleResources>>
|
||||||
local_model_asset_bundle_resources_;
|
local_model_asset_bundle_resources_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user