Add model bundle in hand landmark task.

PiperOrigin-RevId: 481960266
This commit is contained in:
Yuqi Li 2022-10-18 10:36:43 -07:00 committed by Copybara-Service
parent 6bb5ff989d
commit bc47589c9b
15 changed files with 249 additions and 27 deletions

View File

@ -73,6 +73,7 @@ cc_library(
srcs = ["model_task_graph.cc"],
hdrs = ["model_task_graph.h"],
deps = [
":model_asset_bundle_resources",
":model_resources",
":model_resources_cache",
":model_resources_calculator",
@ -163,6 +164,7 @@ cc_library_with_tflite(
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
":model_asset_bundle_resources",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:packet",
"//mediapipe/tasks/cc:common",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -39,12 +40,16 @@ ModelResourcesCache::ModelResourcesCache(
graph_op_resolver_packet_ =
api2::PacketAdopting<tflite::OpResolver>(std::move(graph_op_resolver));
}
};
}
bool ModelResourcesCache::Exists(const std::string& tag) const {
return model_resources_collection_.contains(tag);
}
bool ModelResourcesCache::ModelAssetBundleExists(const std::string& tag) const {
return model_asset_bundle_resources_collection_.contains(tag);
}
absl::Status ModelResourcesCache::AddModelResources(
std::unique_ptr<ModelResources> model_resources) {
if (model_resources == nullptr) {
@ -94,6 +99,62 @@ absl::StatusOr<const ModelResources*> ModelResourcesCache::GetModelResources(
return model_resources_collection_.at(tag).get();
}
absl::Status ModelResourcesCache::AddModelAssetBundleResources(
std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources) {
if (model_asset_bundle_resources == nullptr) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ModelAssetBundleResources object is null.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
const std::string& tag = model_asset_bundle_resources->GetTag();
if (tag.empty()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ModelAssetBundleResources must have a non-empty tag.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
if (ModelAssetBundleExists(tag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute(
"ModelAssetBundleResources with tag \"$0\" already exists.", tag),
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
model_asset_bundle_resources_collection_.emplace(
tag, std::move(model_asset_bundle_resources));
return absl::OkStatus();
}
absl::Status ModelResourcesCache::AddModelAssetBundleResourcesCollection(
std::vector<std::unique_ptr<ModelAssetBundleResources>>&
model_asset_bundle_resources_collection) {
for (auto& model_bundle_resources : model_asset_bundle_resources_collection) {
MP_RETURN_IF_ERROR(
AddModelAssetBundleResources(std::move(model_bundle_resources)));
}
return absl::OkStatus();
}
absl::StatusOr<const ModelAssetBundleResources*>
ModelResourcesCache::GetModelAssetBundleResources(
const std::string& tag) const {
if (tag.empty()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ModelAssetBundleResources must be retrieved with a non-empty tag.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
if (!ModelAssetBundleExists(tag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute(
"ModelAssetBundleResources with tag \"$0\" does not exist.", tag),
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
return model_asset_bundle_resources_collection_.at(tag).get();
}
absl::StatusOr<api2::Packet<tflite::OpResolver>>
ModelResourcesCache::GetGraphOpResolverPacket() const {
if (graph_op_resolver_packet_.IsEmpty()) {

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -46,6 +47,10 @@ class ModelResourcesCache {
// Returns whether the tag exists in the model resources cache.
bool Exists(const std::string& tag) const;
// Returns whether the tag of the model asset bundle exists in the model
// resources cache.
bool ModelAssetBundleExists(const std::string& tag) const;
// Adds a ModelResources object into the cache.
// The tag of the ModelResources must be unique; the ownership of the
// ModelResources will be transferred into the cache.
@ -62,6 +67,23 @@ class ModelResourcesCache {
absl::StatusOr<const ModelResources*> GetModelResources(
const std::string& tag) const;
// Adds a ModelAssetBundleResources object into the cache.
// The tag of the ModelAssetBundleResources must be unique; the ownership of
// the ModelAssetBundleResources will be transferred into the cache.
absl::Status AddModelAssetBundleResources(
std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources);
// Adds a collection of the ModelAssetBundleResources objects into the cache.
// The tag of the each ModelAssetBundleResources must be unique; the ownership
// of every ModelAssetBundleResources will be transferred into the cache.
absl::Status AddModelAssetBundleResourcesCollection(
std::vector<std::unique_ptr<ModelAssetBundleResources>>&
model_asset_bundle_resources_collection);
// Retrieves a const ModelAssetBundleResources pointer by the unique tag.
absl::StatusOr<const ModelAssetBundleResources*> GetModelAssetBundleResources(
const std::string& tag) const;
// Retrieves the graph op resolver packet.
absl::StatusOr<api2::Packet<tflite::OpResolver>> GetGraphOpResolverPacket()
const;
@ -73,6 +95,11 @@ class ModelResourcesCache {
// A collection of ModelResources objects for the models in the graph.
absl::flat_hash_map<std::string, std::unique_ptr<ModelResources>>
model_resources_collection_;
// A collection of ModelAssetBundleResources objects for the model bundles in
// the graph.
absl::flat_hash_map<std::string, std::unique_ptr<ModelAssetBundleResources>>
model_asset_bundle_resources_collection_;
};
// Global service for mediapipe task model resources cache.

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
@ -70,6 +71,17 @@ std::string CreateModelResourcesTag(const CalculatorGraphConfig::Node& node) {
node_type);
}
std::string CreateModelAssetBundleResourcesTag(
const CalculatorGraphConfig::Node& node) {
std::vector<std::string> names = absl::StrSplit(node.name(), "__");
std::string node_type = node.calculator();
std::replace(node_type.begin(), node_type.end(), '.', '_');
absl::AsciiStrToLower(&node_type);
return absl::StrFormat("%s_%s_model_asset_bundle_resources",
names.back().empty() ? "unnamed" : names.back(),
node_type);
}
} // namespace
// Defines the mediapipe task inference unit as a MediaPipe subgraph that
@ -168,6 +180,38 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
return model_resources_cache_service.GetObject().GetModelResources(tag);
}
absl::StatusOr<const ModelAssetBundleResources*>
ModelTaskGraph::CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
bool has_file_pointer_meta = external_file->has_file_pointer_meta();
// if external file is set by file pointer, no need to add the model asset
// bundle resources into the model resources service since the memory is
// not owned by this model asset bundle resources.
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
ASSIGN_OR_RETURN(
local_model_asset_bundle_resources_,
ModelAssetBundleResources::Create("", std::move(external_file)));
if (!has_file_pointer_meta) {
LOG(WARNING)
<< "A local ModelResources object is created. Please consider using "
"ModelResourcesCacheService to cache the created ModelResources "
"object in the CalculatorGraph.";
}
return local_model_asset_bundle_resources_.get();
}
const std::string tag =
CreateModelAssetBundleResourcesTag(sc->OriginalNode());
ASSIGN_OR_RETURN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(tag, std::move(external_file)));
MP_RETURN_IF_ERROR(
model_resources_cache_service.GetObject().AddModelAssetBundleResources(
std::move(model_bundle_resources)));
return model_resources_cache_service.GetObject().GetModelAssetBundleResources(
tag);
}
GenericNode& ModelTaskGraph::AddInference(
const ModelResources& model_resources,
const proto::Acceleration& acceleration, Graph& graph) const {

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -78,6 +79,35 @@ class ModelTaskGraph : public Subgraph {
absl::StatusOr<const ModelResources*> CreateModelResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
// If the model resources graph service is available, creates a model asset
// bundle resources object from the subgraph context, and caches the created
// model asset bundle resources into the model resources graph service on
// success. Otherwise, creates a local model asset bundle resources object
// that can only be used in the graph construction stage. The returned model
// resources pointer will provide graph authors with the access to extracted
// model files.
template <typename Options>
absl::StatusOr<const ModelAssetBundleResources*>
CreateModelAssetBundleResources(SubgraphContext* sc) {
auto external_file = std::make_unique<proto::ExternalFile>();
external_file->Swap(sc->MutableOptions<Options>()
->mutable_base_options()
->mutable_model_asset());
return CreateModelAssetBundleResources(sc, std::move(external_file));
}
// If the model resources graph service is available, creates a model asset
// bundle resources object from the subgraph context, and caches the created
// model asset bundle resources into the model resources graph service on
// success. Otherwise, creates a local model asset bundle resources object
// that can only be used in the graph construction stage. Note that the
// external file contents will be moved into the model asset bundle resources
// object on creation. The returned model asset bundle resources pointer will
// provide graph authors with the access to extracted model files.
absl::StatusOr<const ModelAssetBundleResources*>
CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
// Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the
// the rest of the graph:
@ -95,6 +125,9 @@ class ModelTaskGraph : public Subgraph {
private:
std::unique_ptr<ModelResources> local_model_resources_;
std::unique_ptr<ModelAssetBundleResources>
local_model_asset_bundle_resources_;
};
} // namespace core

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include <string>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
@ -162,13 +164,17 @@ absl::Status ExtractFilesfromZipFile(
return absl::OkStatus();
}
void SetExternalFile(const std::string_view& file_content,
core::proto::ExternalFile* model_file) {
void SetExternalFile(const absl::string_view& file_content,
core::proto::ExternalFile* model_file, bool is_copy) {
if (is_copy) {
std::string str_content{file_content};
model_file->set_file_content(str_content);
} else {
auto pointer = reinterpret_cast<uint64_t>(file_content.data());
model_file->mutable_file_pointer_meta()->set_pointer(pointer);
model_file->mutable_file_pointer_meta()->set_length(file_content.length());
}
}
} // namespace metadata
} // namespace tasks

View File

@ -35,10 +35,13 @@ absl::Status ExtractFilesfromZipFile(
const char* buffer_data, const size_t buffer_size,
absl::flat_hash_map<std::string, absl::string_view>* files);
// Set file_pointer_meta in ExternalFile which is the pointer points to location
// of a file in memory by file_content.
void SetExternalFile(const std::string_view& file_content,
core::proto::ExternalFile* model_file);
// Set the ExternalFile object by file_content in memory. By default,
// `is_copy=false` which means to set `file_pointer_meta` in ExternalFile which
// is the pointer points to location of a file in memory. Otherwise, if
// `is_copy=true`, copy the memory into `file_content` in ExternalFile.
void SetExternalFile(const absl::string_view& file_content,
core::proto::ExternalFile* model_file,
bool is_copy = false);
} // namespace metadata
} // namespace tasks

View File

@ -91,10 +91,14 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",

View File

@ -29,10 +29,14 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
@ -50,6 +54,8 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::DisallowIf;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::hand_detector::proto::
HandDetectorGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
@ -65,6 +71,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
constexpr char kHandDetectorTFLiteName[] = "hand_detector.tflite";
constexpr char kHandLandmarksDetectorTFLiteName[] =
"hand_landmarks_detector.tflite";
struct HandLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
@ -76,6 +85,27 @@ struct HandLandmarkerOutputs {
Source<Image> image;
};
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandLandmarkerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName));
SetExternalFile(hand_detector_file,
options->mutable_hand_detector_graph_options()
->mutable_base_options()
->mutable_model_asset(),
is_copy);
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
SetExternalFile(hand_landmarks_detector_file,
options->mutable_hand_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset(),
is_copy);
return absl::OkStatus();
}
} // namespace
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
@ -154,6 +184,20 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
if (sc->Options<HandLandmarkerGraphOptions>()
.base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<HandLandmarkerGraphOptions>(sc));
// Copies the file content instead of passing the pointer of file in
// memory if the subgraph model resource service is not available.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<HandLandmarkerGraphOptions>(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(
auto hand_landmarker_outputs,
BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),

View File

@ -65,8 +65,7 @@ using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task";
constexpr char kLeftHandsImage[] = "left_hands.jpg";
constexpr char kImageTag[] = "IMAGE";
@ -105,17 +104,9 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
auto& options =
hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
options.mutable_hand_detector_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel));
options.mutable_hand_detector_graph_options()->mutable_base_options();
options.mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, kHandLandmarkerModelBundle));
options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
options.mutable_hand_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(
JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel));
options.set_min_tracking_confidence(kMinTrackingConfidence);
graph[Input<Image>(kImageTag)].SetName(kImageName) >>

View File

@ -29,8 +29,8 @@ message HandLandmarkerGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandLandmarkerGraphOptions ext = 462713202;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
// Base options for configuring MediaPipe Tasks, such as specifying the model
// asset bundle file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for hand detector graph.

View File

@ -433,8 +433,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
.setUseStreamMode(runningMode() != RunningMode.IMAGE));
minTrackingConfidence()
.ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
handLandmarkerGraphOptionsBuilder

View File

@ -35,6 +35,7 @@ mediapipe_files(srcs = [
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"deeplabv3.tflite",
"hand_landmark.task",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"left_hands.jpg",
@ -109,6 +110,7 @@ filegroup(
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"deeplabv3.tflite",
"hand_landmark.task",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite",

Binary file not shown.

View File

@ -268,6 +268,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark_lite.tflite?generation=1661875766398729"],
)
http_file(
name = "com_google_mediapipe_hand_landmark_task",
sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"],
)
http_file(
name = "com_google_mediapipe_hand_recrop_tflite",
sha256 = "67d996ce96f9d36fe17d2693022c6da93168026ab2f028f9e2365398d8ac7d5d",