Add model bundle in hand landmark task.
PiperOrigin-RevId: 481960266
This commit is contained in:
parent
6bb5ff989d
commit
bc47589c9b
|
@ -73,6 +73,7 @@ cc_library(
|
||||||
srcs = ["model_task_graph.cc"],
|
srcs = ["model_task_graph.cc"],
|
||||||
hdrs = ["model_task_graph.h"],
|
hdrs = ["model_task_graph.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":model_asset_bundle_resources",
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
":model_resources_calculator",
|
":model_resources_calculator",
|
||||||
|
@ -163,6 +164,7 @@ cc_library_with_tflite(
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":model_asset_bundle_resources",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/api2:packet",
|
"//mediapipe/framework/api2:packet",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
@ -39,12 +40,16 @@ ModelResourcesCache::ModelResourcesCache(
|
||||||
graph_op_resolver_packet_ =
|
graph_op_resolver_packet_ =
|
||||||
api2::PacketAdopting<tflite::OpResolver>(std::move(graph_op_resolver));
|
api2::PacketAdopting<tflite::OpResolver>(std::move(graph_op_resolver));
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
bool ModelResourcesCache::Exists(const std::string& tag) const {
|
bool ModelResourcesCache::Exists(const std::string& tag) const {
|
||||||
return model_resources_collection_.contains(tag);
|
return model_resources_collection_.contains(tag);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ModelResourcesCache::ModelAssetBundleExists(const std::string& tag) const {
|
||||||
|
return model_asset_bundle_resources_collection_.contains(tag);
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status ModelResourcesCache::AddModelResources(
|
absl::Status ModelResourcesCache::AddModelResources(
|
||||||
std::unique_ptr<ModelResources> model_resources) {
|
std::unique_ptr<ModelResources> model_resources) {
|
||||||
if (model_resources == nullptr) {
|
if (model_resources == nullptr) {
|
||||||
|
@ -94,6 +99,62 @@ absl::StatusOr<const ModelResources*> ModelResourcesCache::GetModelResources(
|
||||||
return model_resources_collection_.at(tag).get();
|
return model_resources_collection_.at(tag).get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status ModelResourcesCache::AddModelAssetBundleResources(
|
||||||
|
std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources) {
|
||||||
|
if (model_asset_bundle_resources == nullptr) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
"ModelAssetBundleResources object is null.",
|
||||||
|
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
|
||||||
|
}
|
||||||
|
const std::string& tag = model_asset_bundle_resources->GetTag();
|
||||||
|
if (tag.empty()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
"ModelAssetBundleResources must have a non-empty tag.",
|
||||||
|
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
|
||||||
|
}
|
||||||
|
if (ModelAssetBundleExists(tag)) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::Substitute(
|
||||||
|
"ModelAssetBundleResources with tag \"$0\" already exists.", tag),
|
||||||
|
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
|
||||||
|
}
|
||||||
|
model_asset_bundle_resources_collection_.emplace(
|
||||||
|
tag, std::move(model_asset_bundle_resources));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status ModelResourcesCache::AddModelAssetBundleResourcesCollection(
|
||||||
|
std::vector<std::unique_ptr<ModelAssetBundleResources>>&
|
||||||
|
model_asset_bundle_resources_collection) {
|
||||||
|
for (auto& model_bundle_resources : model_asset_bundle_resources_collection) {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
AddModelAssetBundleResources(std::move(model_bundle_resources)));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<const ModelAssetBundleResources*>
|
||||||
|
ModelResourcesCache::GetModelAssetBundleResources(
|
||||||
|
const std::string& tag) const {
|
||||||
|
if (tag.empty()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
"ModelAssetBundleResources must be retrieved with a non-empty tag.",
|
||||||
|
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
|
||||||
|
}
|
||||||
|
if (!ModelAssetBundleExists(tag)) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::Substitute(
|
||||||
|
"ModelAssetBundleResources with tag \"$0\" does not exist.", tag),
|
||||||
|
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
|
||||||
|
}
|
||||||
|
return model_asset_bundle_resources_collection_.at(tag).get();
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<api2::Packet<tflite::OpResolver>>
|
absl::StatusOr<api2::Packet<tflite::OpResolver>>
|
||||||
ModelResourcesCache::GetGraphOpResolverPacket() const {
|
ModelResourcesCache::GetGraphOpResolverPacket() const {
|
||||||
if (graph_op_resolver_packet_.IsEmpty()) {
|
if (graph_op_resolver_packet_.IsEmpty()) {
|
||||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
@ -46,6 +47,10 @@ class ModelResourcesCache {
|
||||||
// Returns whether the tag exists in the model resources cache.
|
// Returns whether the tag exists in the model resources cache.
|
||||||
bool Exists(const std::string& tag) const;
|
bool Exists(const std::string& tag) const;
|
||||||
|
|
||||||
|
// Returns whether the tag of the model asset bundle exists in the model
|
||||||
|
// resources cache.
|
||||||
|
bool ModelAssetBundleExists(const std::string& tag) const;
|
||||||
|
|
||||||
// Adds a ModelResources object into the cache.
|
// Adds a ModelResources object into the cache.
|
||||||
// The tag of the ModelResources must be unique; the ownership of the
|
// The tag of the ModelResources must be unique; the ownership of the
|
||||||
// ModelResources will be transferred into the cache.
|
// ModelResources will be transferred into the cache.
|
||||||
|
@ -62,6 +67,23 @@ class ModelResourcesCache {
|
||||||
absl::StatusOr<const ModelResources*> GetModelResources(
|
absl::StatusOr<const ModelResources*> GetModelResources(
|
||||||
const std::string& tag) const;
|
const std::string& tag) const;
|
||||||
|
|
||||||
|
// Adds a ModelAssetBundleResources object into the cache.
|
||||||
|
// The tag of the ModelAssetBundleResources must be unique; the ownership of
|
||||||
|
// the ModelAssetBundleResources will be transferred into the cache.
|
||||||
|
absl::Status AddModelAssetBundleResources(
|
||||||
|
std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources);
|
||||||
|
|
||||||
|
// Adds a collection of the ModelAssetBundleResources objects into the cache.
|
||||||
|
// The tag of the each ModelAssetBundleResources must be unique; the ownership
|
||||||
|
// of every ModelAssetBundleResources will be transferred into the cache.
|
||||||
|
absl::Status AddModelAssetBundleResourcesCollection(
|
||||||
|
std::vector<std::unique_ptr<ModelAssetBundleResources>>&
|
||||||
|
model_asset_bundle_resources_collection);
|
||||||
|
|
||||||
|
// Retrieves a const ModelAssetBundleResources pointer by the unique tag.
|
||||||
|
absl::StatusOr<const ModelAssetBundleResources*> GetModelAssetBundleResources(
|
||||||
|
const std::string& tag) const;
|
||||||
|
|
||||||
// Retrieves the graph op resolver packet.
|
// Retrieves the graph op resolver packet.
|
||||||
absl::StatusOr<api2::Packet<tflite::OpResolver>> GetGraphOpResolverPacket()
|
absl::StatusOr<api2::Packet<tflite::OpResolver>> GetGraphOpResolverPacket()
|
||||||
const;
|
const;
|
||||||
|
@ -73,6 +95,11 @@ class ModelResourcesCache {
|
||||||
// A collection of ModelResources objects for the models in the graph.
|
// A collection of ModelResources objects for the models in the graph.
|
||||||
absl::flat_hash_map<std::string, std::unique_ptr<ModelResources>>
|
absl::flat_hash_map<std::string, std::unique_ptr<ModelResources>>
|
||||||
model_resources_collection_;
|
model_resources_collection_;
|
||||||
|
|
||||||
|
// A collection of ModelAssetBundleResources objects for the model bundles in
|
||||||
|
// the graph.
|
||||||
|
absl::flat_hash_map<std::string, std::unique_ptr<ModelAssetBundleResources>>
|
||||||
|
model_asset_bundle_resources_collection_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Global service for mediapipe task model resources cache.
|
// Global service for mediapipe task model resources cache.
|
||||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||||
|
@ -70,6 +71,17 @@ std::string CreateModelResourcesTag(const CalculatorGraphConfig::Node& node) {
|
||||||
node_type);
|
node_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string CreateModelAssetBundleResourcesTag(
|
||||||
|
const CalculatorGraphConfig::Node& node) {
|
||||||
|
std::vector<std::string> names = absl::StrSplit(node.name(), "__");
|
||||||
|
std::string node_type = node.calculator();
|
||||||
|
std::replace(node_type.begin(), node_type.end(), '.', '_');
|
||||||
|
absl::AsciiStrToLower(&node_type);
|
||||||
|
return absl::StrFormat("%s_%s_model_asset_bundle_resources",
|
||||||
|
names.back().empty() ? "unnamed" : names.back(),
|
||||||
|
node_type);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Defines the mediapipe task inference unit as a MediaPipe subgraph that
|
// Defines the mediapipe task inference unit as a MediaPipe subgraph that
|
||||||
|
@ -168,6 +180,38 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
|
||||||
return model_resources_cache_service.GetObject().GetModelResources(tag);
|
return model_resources_cache_service.GetObject().GetModelResources(tag);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<const ModelAssetBundleResources*>
|
||||||
|
ModelTaskGraph::CreateModelAssetBundleResources(
|
||||||
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
|
||||||
|
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
|
||||||
|
bool has_file_pointer_meta = external_file->has_file_pointer_meta();
|
||||||
|
// if external file is set by file pointer, no need to add the model asset
|
||||||
|
// bundle resources into the model resources service since the memory is
|
||||||
|
// not owned by this model asset bundle resources.
|
||||||
|
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
local_model_asset_bundle_resources_,
|
||||||
|
ModelAssetBundleResources::Create("", std::move(external_file)));
|
||||||
|
if (!has_file_pointer_meta) {
|
||||||
|
LOG(WARNING)
|
||||||
|
<< "A local ModelResources object is created. Please consider using "
|
||||||
|
"ModelResourcesCacheService to cache the created ModelResources "
|
||||||
|
"object in the CalculatorGraph.";
|
||||||
|
}
|
||||||
|
return local_model_asset_bundle_resources_.get();
|
||||||
|
}
|
||||||
|
const std::string tag =
|
||||||
|
CreateModelAssetBundleResourcesTag(sc->OriginalNode());
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
auto model_bundle_resources,
|
||||||
|
ModelAssetBundleResources::Create(tag, std::move(external_file)));
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
model_resources_cache_service.GetObject().AddModelAssetBundleResources(
|
||||||
|
std::move(model_bundle_resources)));
|
||||||
|
return model_resources_cache_service.GetObject().GetModelAssetBundleResources(
|
||||||
|
tag);
|
||||||
|
}
|
||||||
|
|
||||||
GenericNode& ModelTaskGraph::AddInference(
|
GenericNode& ModelTaskGraph::AddInference(
|
||||||
const ModelResources& model_resources,
|
const ModelResources& model_resources,
|
||||||
const proto::Acceleration& acceleration, Graph& graph) const {
|
const proto::Acceleration& acceleration, Graph& graph) const {
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/subgraph.h"
|
#include "mediapipe/framework/subgraph.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
|
@ -78,6 +79,35 @@ class ModelTaskGraph : public Subgraph {
|
||||||
absl::StatusOr<const ModelResources*> CreateModelResources(
|
absl::StatusOr<const ModelResources*> CreateModelResources(
|
||||||
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
|
||||||
|
|
||||||
|
// If the model resources graph service is available, creates a model asset
|
||||||
|
// bundle resources object from the subgraph context, and caches the created
|
||||||
|
// model asset bundle resources into the model resources graph service on
|
||||||
|
// success. Otherwise, creates a local model asset bundle resources object
|
||||||
|
// that can only be used in the graph construction stage. The returned model
|
||||||
|
// resources pointer will provide graph authors with the access to extracted
|
||||||
|
// model files.
|
||||||
|
template <typename Options>
|
||||||
|
absl::StatusOr<const ModelAssetBundleResources*>
|
||||||
|
CreateModelAssetBundleResources(SubgraphContext* sc) {
|
||||||
|
auto external_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
external_file->Swap(sc->MutableOptions<Options>()
|
||||||
|
->mutable_base_options()
|
||||||
|
->mutable_model_asset());
|
||||||
|
return CreateModelAssetBundleResources(sc, std::move(external_file));
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the model resources graph service is available, creates a model asset
|
||||||
|
// bundle resources object from the subgraph context, and caches the created
|
||||||
|
// model asset bundle resources into the model resources graph service on
|
||||||
|
// success. Otherwise, creates a local model asset bundle resources object
|
||||||
|
// that can only be used in the graph construction stage. Note that the
|
||||||
|
// external file contents will be moved into the model asset bundle resources
|
||||||
|
// object on creation. The returned model asset bundle resources pointer will
|
||||||
|
// provide graph authors with the access to extracted model files.
|
||||||
|
absl::StatusOr<const ModelAssetBundleResources*>
|
||||||
|
CreateModelAssetBundleResources(
|
||||||
|
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
|
||||||
|
|
||||||
// Inserts a mediapipe task inference subgraph into the provided
|
// Inserts a mediapipe task inference subgraph into the provided
|
||||||
// GraphBuilder. The returned node provides the following interfaces to the
|
// GraphBuilder. The returned node provides the following interfaces to the
|
||||||
// the rest of the graph:
|
// the rest of the graph:
|
||||||
|
@ -95,6 +125,9 @@ class ModelTaskGraph : public Subgraph {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<ModelResources> local_model_resources_;
|
std::unique_ptr<ModelResources> local_model_resources_;
|
||||||
|
|
||||||
|
std::unique_ptr<ModelAssetBundleResources>
|
||||||
|
local_model_asset_bundle_resources_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "absl/cleanup/cleanup.h"
|
#include "absl/cleanup/cleanup.h"
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
@ -162,12 +164,16 @@ absl::Status ExtractFilesfromZipFile(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetExternalFile(const std::string_view& file_content,
|
void SetExternalFile(const absl::string_view& file_content,
|
||||||
core::proto::ExternalFile* model_file) {
|
core::proto::ExternalFile* model_file, bool is_copy) {
|
||||||
|
if (is_copy) {
|
||||||
|
std::string str_content{file_content};
|
||||||
|
model_file->set_file_content(str_content);
|
||||||
|
} else {
|
||||||
auto pointer = reinterpret_cast<uint64_t>(file_content.data());
|
auto pointer = reinterpret_cast<uint64_t>(file_content.data());
|
||||||
|
|
||||||
model_file->mutable_file_pointer_meta()->set_pointer(pointer);
|
model_file->mutable_file_pointer_meta()->set_pointer(pointer);
|
||||||
model_file->mutable_file_pointer_meta()->set_length(file_content.length());
|
model_file->mutable_file_pointer_meta()->set_length(file_content.length());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace metadata
|
} // namespace metadata
|
||||||
|
|
|
@ -35,10 +35,13 @@ absl::Status ExtractFilesfromZipFile(
|
||||||
const char* buffer_data, const size_t buffer_size,
|
const char* buffer_data, const size_t buffer_size,
|
||||||
absl::flat_hash_map<std::string, absl::string_view>* files);
|
absl::flat_hash_map<std::string, absl::string_view>* files);
|
||||||
|
|
||||||
// Set file_pointer_meta in ExternalFile which is the pointer points to location
|
// Set the ExternalFile object by file_content in memory. By default,
|
||||||
// of a file in memory by file_content.
|
// `is_copy=false` which means to set `file_pointer_meta` in ExternalFile which
|
||||||
void SetExternalFile(const std::string_view& file_content,
|
// is the pointer points to location of a file in memory. Otherwise, if
|
||||||
core::proto::ExternalFile* model_file);
|
// `is_copy=true`, copy the memory into `file_content` in ExternalFile.
|
||||||
|
void SetExternalFile(const absl::string_view& file_content,
|
||||||
|
core::proto::ExternalFile* model_file,
|
||||||
|
bool is_copy = false);
|
||||||
|
|
||||||
} // namespace metadata
|
} // namespace metadata
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
|
@ -91,10 +91,14 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/utils:gate",
|
"//mediapipe/tasks/cc/components/utils:gate",
|
||||||
|
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources_cache",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
|
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
|
||||||
|
|
|
@ -29,10 +29,14 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/utils/gate.h"
|
#include "mediapipe/tasks/cc/components/utils/gate.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
|
||||||
|
@ -50,6 +54,8 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::utils::DisallowIf;
|
using ::mediapipe::tasks::components::utils::DisallowIf;
|
||||||
|
using ::mediapipe::tasks::core::ModelAssetBundleResources;
|
||||||
|
using ::mediapipe::tasks::metadata::SetExternalFile;
|
||||||
using ::mediapipe::tasks::vision::hand_detector::proto::
|
using ::mediapipe::tasks::vision::hand_detector::proto::
|
||||||
HandDetectorGraphOptions;
|
HandDetectorGraphOptions;
|
||||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||||
|
@ -65,6 +71,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
|
||||||
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
|
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
|
||||||
constexpr char kPalmRectsTag[] = "PALM_RECTS";
|
constexpr char kPalmRectsTag[] = "PALM_RECTS";
|
||||||
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
|
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
|
||||||
|
constexpr char kHandDetectorTFLiteName[] = "hand_detector.tflite";
|
||||||
|
constexpr char kHandLandmarksDetectorTFLiteName[] =
|
||||||
|
"hand_landmarks_detector.tflite";
|
||||||
|
|
||||||
struct HandLandmarkerOutputs {
|
struct HandLandmarkerOutputs {
|
||||||
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
|
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
|
||||||
|
@ -76,6 +85,27 @@ struct HandLandmarkerOutputs {
|
||||||
Source<Image> image;
|
Source<Image> image;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Sets the base options in the sub tasks.
|
||||||
|
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
||||||
|
HandLandmarkerGraphOptions* options,
|
||||||
|
bool is_copy) {
|
||||||
|
ASSIGN_OR_RETURN(const auto hand_detector_file,
|
||||||
|
resources.GetModelFile(kHandDetectorTFLiteName));
|
||||||
|
SetExternalFile(hand_detector_file,
|
||||||
|
options->mutable_hand_detector_graph_options()
|
||||||
|
->mutable_base_options()
|
||||||
|
->mutable_model_asset(),
|
||||||
|
is_copy);
|
||||||
|
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
|
||||||
|
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
|
||||||
|
SetExternalFile(hand_landmarks_detector_file,
|
||||||
|
options->mutable_hand_landmarks_detector_graph_options()
|
||||||
|
->mutable_base_options()
|
||||||
|
->mutable_model_asset(),
|
||||||
|
is_copy);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
|
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
|
||||||
|
@ -154,6 +184,20 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
if (sc->Options<HandLandmarkerGraphOptions>()
|
||||||
|
.base_options()
|
||||||
|
.has_model_asset()) {
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
const auto* model_asset_bundle_resources,
|
||||||
|
CreateModelAssetBundleResources<HandLandmarkerGraphOptions>(sc));
|
||||||
|
// Copies the file content instead of passing the pointer of file in
|
||||||
|
// memory if the subgraph model resource service is not available.
|
||||||
|
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
|
||||||
|
*model_asset_bundle_resources,
|
||||||
|
sc->MutableOptions<HandLandmarkerGraphOptions>(),
|
||||||
|
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||||
|
.IsAvailable()));
|
||||||
|
}
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto hand_landmarker_outputs,
|
auto hand_landmarker_outputs,
|
||||||
BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
|
BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
|
||||||
|
|
|
@ -65,8 +65,7 @@ using ::testing::proto::Approximately;
|
||||||
using ::testing::proto::Partially;
|
using ::testing::proto::Partially;
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||||
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
|
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task";
|
||||||
constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite";
|
|
||||||
constexpr char kLeftHandsImage[] = "left_hands.jpg";
|
constexpr char kLeftHandsImage[] = "left_hands.jpg";
|
||||||
|
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
@ -105,17 +104,9 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
||||||
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
|
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
|
||||||
auto& options =
|
auto& options =
|
||||||
hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
|
hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
|
||||||
options.mutable_hand_detector_graph_options()
|
options.mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||||
->mutable_base_options()
|
JoinPath("./", kTestDataDirectory, kHandLandmarkerModelBundle));
|
||||||
->mutable_model_asset()
|
|
||||||
->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel));
|
|
||||||
options.mutable_hand_detector_graph_options()->mutable_base_options();
|
|
||||||
options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
|
options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
|
||||||
options.mutable_hand_landmarks_detector_graph_options()
|
|
||||||
->mutable_base_options()
|
|
||||||
->mutable_model_asset()
|
|
||||||
->set_file_name(
|
|
||||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel));
|
|
||||||
options.set_min_tracking_confidence(kMinTrackingConfidence);
|
options.set_min_tracking_confidence(kMinTrackingConfidence);
|
||||||
|
|
||||||
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
||||||
|
|
|
@ -29,8 +29,8 @@ message HandLandmarkerGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional HandLandmarkerGraphOptions ext = 462713202;
|
optional HandLandmarkerGraphOptions ext = 462713202;
|
||||||
}
|
}
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
// model file with metadata, accelerator options, etc.
|
// asset bundle file with metadata, accelerator options, etc.
|
||||||
optional core.proto.BaseOptions base_options = 1;
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
||||||
// Options for hand detector graph.
|
// Options for hand detector graph.
|
||||||
|
|
|
@ -433,8 +433,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
||||||
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
|
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
|
||||||
.setBaseOptions(
|
.setBaseOptions(
|
||||||
BaseOptionsProto.BaseOptions.newBuilder()
|
BaseOptionsProto.BaseOptions.newBuilder()
|
||||||
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
|
.setUseStreamMode(runningMode() != RunningMode.IMAGE));
|
||||||
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
|
|
||||||
minTrackingConfidence()
|
minTrackingConfidence()
|
||||||
.ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
|
.ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
|
||||||
handLandmarkerGraphOptionsBuilder
|
handLandmarkerGraphOptionsBuilder
|
||||||
|
|
2
mediapipe/tasks/testdata/vision/BUILD
vendored
2
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -35,6 +35,7 @@ mediapipe_files(srcs = [
|
||||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
|
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
|
||||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
|
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
|
||||||
"deeplabv3.tflite",
|
"deeplabv3.tflite",
|
||||||
|
"hand_landmark.task",
|
||||||
"hand_landmark_full.tflite",
|
"hand_landmark_full.tflite",
|
||||||
"hand_landmark_lite.tflite",
|
"hand_landmark_lite.tflite",
|
||||||
"left_hands.jpg",
|
"left_hands.jpg",
|
||||||
|
@ -109,6 +110,7 @@ filegroup(
|
||||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
|
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
|
||||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
|
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
|
||||||
"deeplabv3.tflite",
|
"deeplabv3.tflite",
|
||||||
|
"hand_landmark.task",
|
||||||
"hand_landmark_full.tflite",
|
"hand_landmark_full.tflite",
|
||||||
"hand_landmark_lite.tflite",
|
"hand_landmark_lite.tflite",
|
||||||
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
|
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
|
||||||
|
|
BIN
mediapipe/tasks/testdata/vision/hand_landmark.task
vendored
Normal file
BIN
mediapipe/tasks/testdata/vision/hand_landmark.task
vendored
Normal file
Binary file not shown.
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -268,6 +268,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark_lite.tflite?generation=1661875766398729"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark_lite.tflite?generation=1661875766398729"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_hand_landmark_task",
|
||||||
|
sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_hand_recrop_tflite",
|
name = "com_google_mediapipe_hand_recrop_tflite",
|
||||||
sha256 = "67d996ce96f9d36fe17d2693022c6da93168026ab2f028f9e2365398d8ac7d5d",
|
sha256 = "67d996ce96f9d36fe17d2693022c6da93168026ab2f028f9e2365398d8ac7d5d",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user