diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc index 5867be49b..2f53ff2d5 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc @@ -51,12 +51,11 @@ ModelAssetBundleResources::Create( auto model_bundle_resources = absl::WrapUnique( new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file))); MP_RETURN_IF_ERROR( - model_bundle_resources->ExtractModelFilesFromExternalFileProto()); + model_bundle_resources->ExtractFilesFromExternalFileProto()); return model_bundle_resources; } -absl::Status -ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() { +absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() { if (model_asset_bundle_file_->has_file_name()) { // If the model asset bundle file name is a relative path, searches the file // in a platform-specific location and returns the absolute path on success. @@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() { model_asset_bundle_file_handler_->GetFileContent().data(); size_t buffer_size = model_asset_bundle_file_handler_->GetFileContent().size(); - return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, - &model_files_); + return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_); } -absl::StatusOr ModelAssetBundleResources::GetModelFile( +absl::StatusOr ModelAssetBundleResources::GetFile( const std::string& filename) const { - auto it = model_files_.find(filename); - if (it == model_files_.end()) { - auto model_files = ListModelFiles(); - std::string all_model_files = - absl::StrJoin(model_files.begin(), model_files.end(), ", "); + auto it = files_.find(filename); + if (it == files_.end()) { + auto files = ListFiles(); + std::string all_files = absl::StrJoin(files.begin(), files.end(), ", "); return CreateStatusWithPayload( StatusCode::kNotFound, - absl::StrFormat("No model file with name: %s. All model files in the " - "model asset bundle are: %s.", - filename, all_model_files), + absl::StrFormat("No file with name: %s. All files in the model asset " + "bundle are: %s.", + filename, all_files), MediaPipeTasksStatus::kFileNotFoundError); } return it->second; } -std::vector ModelAssetBundleResources::ListModelFiles() const { - std::vector model_names; - for (const auto& [model_name, _] : model_files_) { - model_names.push_back(model_name); +std::vector ModelAssetBundleResources::ListFiles() const { + std::vector file_names; + for (const auto& [file_name, _] : files_) { + file_names.push_back(file_name); } - return model_names; + return file_names; } } // namespace core diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h index 61474d3ad..02d989d4b 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h @@ -28,8 +28,8 @@ namespace core { // The mediapipe task model asset bundle resources class. // A ModelAssetBundleResources object, created from an external file proto, // contains model asset bundle related resources and the method to extract the -// tflite models or model asset bundles for the mediapipe sub-tasks. As the -// resources are owned by the ModelAssetBundleResources object +// tflite models, resource files or model asset bundles for the mediapipe +// sub-tasks. As the resources are owned by the ModelAssetBundleResources object // callers must keep ModelAssetBundleResources alive while using any of the // resources. class ModelAssetBundleResources { @@ -50,14 +50,13 @@ class ModelAssetBundleResources { // Returns the model asset bundle resources tag. std::string GetTag() const { return tag_; } - // Gets the contents of the model file (either tflite model file or model - // bundle file) with the provided name. An error is returned if there is no - // such model file. - absl::StatusOr GetModelFile( - const std::string& filename) const; + // Gets the contents of the model file (either tflite model file, resource + // file or model bundle file) with the provided name. An error is returned if + // there is no such model file. + absl::StatusOr GetFile(const std::string& filename) const; - // Lists all the model file names in the model asset model. - std::vector ListModelFiles() const; + // Lists all the file names in the model asset model. + std::vector ListFiles() const; private: // Constructor. @@ -65,9 +64,9 @@ class ModelAssetBundleResources { const std::string& tag, std::unique_ptr model_asset_bundle_file); - // Extracts the model files (either tflite model file or model bundle file) - // from the external file proto. - absl::Status ExtractModelFilesFromExternalFileProto(); + // Extracts the model files (either tflite model file, resource file or model + // bundle file) from the external file proto. + absl::Status ExtractFilesFromExternalFileProto(); // The model asset bundle resources tag. const std::string tag_; @@ -78,11 +77,11 @@ class ModelAssetBundleResources { // The ExternalFileHandler for the model asset bundle. std::unique_ptr model_asset_bundle_file_handler_; - // The model files bundled in model asset bundle, as a map with the filename + // The files bundled in model asset bundle, as a map with the filename // (corresponding to a basename, e.g. "hand_detector.tflite") as key and - // a pointer to the file contents as value. Each model file can be either - // a TFLite model file or a model bundle file for sub-task. - absl::flat_hash_map model_files_; + // a pointer to the file contents as value. Each file can be either a TFLite + // model file, resource file or a model bundle file for sub-task. + absl::flat_hash_map files_; }; } // namespace core diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc index 359deef91..85a94ccc7 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc @@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } @@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } @@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } #endif // _WIN32 @@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } @@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); auto status_or_model_bundle_file = - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task"); + model_bundle_resources->GetFile("dummy_hand_landmarker.task"); MP_EXPECT_OK(status_or_model_bundle_file.status()); // Creates sub-task model asset bundle resources. @@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(hand_landmaker_model_file))); MP_EXPECT_OK(hand_landmaker_model_bundle_resources - ->GetModelFile("dummy_hand_detector.tflite") + ->GetFile("dummy_hand_detector.tflite") .status()); MP_EXPECT_OK(hand_landmaker_model_bundle_resources - ->GetModelFile("dummy_hand_landmarker.tflite") + ->GetFile("dummy_hand_landmarker.tflite") .status()); } @@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); auto status_or_model_bundle_file = - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite"); + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite"); MP_EXPECT_OK(status_or_model_bundle_file.status()); // Verify tflite model works. @@ -200,12 +196,12 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) { auto model_bundle_resources, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); - auto status = model_bundle_resources->GetModelFile("not_found.task").status(); + auto status = model_bundle_resources->GetFile("not_found.task").status(); EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_THAT(status.message(), - testing::HasSubstr( - "No model file with name: not_found.task. All model files in " - "the model asset bundle are: ")); + EXPECT_THAT( + status.message(), + testing::HasSubstr("No file with name: not_found.task. All files in " + "the model asset bundle are: ")); EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), testing::Optional(absl::Cord( absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError)))); @@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) { auto model_bundle_resources, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); - auto model_files = model_bundle_resources->ListModelFiles(); + auto model_files = model_bundle_resources->ListFiles(); std::vector expected_model_files = { "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"}; std::sort(model_files.begin(), model_files.end()); diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc index d6cc630b2..78927f27b 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc @@ -116,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, options->mutable_face_detector_graph_options(); if (!face_detector_graph_options->base_options().has_model_asset()) { ASSIGN_OR_RETURN(const auto face_detector_file, - resources.GetModelFile(kFaceDetectorTFLiteName)); + resources.GetFile(kFaceDetectorTFLiteName)); SetExternalFile(face_detector_file, face_detector_graph_options->mutable_base_options() ->mutable_model_asset(), @@ -132,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, if (!face_landmarks_detector_graph_options->base_options() .has_model_asset()) { ASSIGN_OR_RETURN(const auto face_landmarks_detector_file, - resources.GetModelFile(kFaceLandmarksDetectorTFLiteName)); + resources.GetFile(kFaceLandmarksDetectorTFLiteName)); SetExternalFile( face_landmarks_detector_file, face_landmarks_detector_graph_options->mutable_base_options() @@ -146,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->set_use_stream_mode(options->base_options().use_stream_mode()); absl::StatusOr face_blendshape_model = - resources.GetModelFile(kFaceBlendshapeTFLiteName); + resources.GetFile(kFaceBlendshapeTFLiteName); if (face_blendshape_model.ok()) { SetExternalFile(*face_blendshape_model, face_landmarks_detector_graph_options @@ -327,7 +327,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { // Set the face geometry metdata file for // FaceGeometryFromLandmarksGraph. ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file, - model_asset_bundle_resources->GetModelFile( + model_asset_bundle_resources->GetFile( kFaceGeometryPipelineMetadataName)); SetExternalFile(face_geometry_pipeline_metadata_file, sc->MutableOptions() diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 11b2d12c4..55db07cb8 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, GestureRecognizerGraphOptions* options, bool is_copy) { ASSIGN_OR_RETURN(const auto hand_landmarker_file, - resources.GetModelFile(kHandLandmarkerBundleAssetName)); + resources.GetFile(kHandLandmarkerBundleAssetName)); auto* hand_landmarker_graph_options = options->mutable_hand_landmarker_graph_options(); SetExternalFile(hand_landmarker_file, @@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( options->base_options().use_stream_mode()); - ASSIGN_OR_RETURN( - const auto hand_gesture_recognizer_file, - resources.GetModelFile(kHandGestureRecognizerBundleAssetName)); + ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file, + resources.GetFile(kHandGestureRecognizerBundleAssetName)); auto* hand_gesture_recognizer_graph_options = options->mutable_hand_gesture_recognizer_graph_options(); SetExternalFile(hand_gesture_recognizer_file, diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 4db57e85b..3fe999937 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { HandGestureRecognizerGraphOptions* options, bool is_copy) { ASSIGN_OR_RETURN(const auto gesture_embedder_file, - resources.GetModelFile(kGestureEmbedderTFLiteName)); + resources.GetFile(kGestureEmbedderTFLiteName)); auto* gesture_embedder_graph_options = options->mutable_gesture_embedder_graph_options(); SetExternalFile(gesture_embedder_file, @@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { options->base_options(), gesture_embedder_graph_options->mutable_base_options()); - ASSIGN_OR_RETURN( - const auto canned_gesture_classifier_file, - resources.GetModelFile(kCannedGestureClassifierTFLiteName)); + ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file, + resources.GetFile(kCannedGestureClassifierTFLiteName)); auto* canned_gesture_classifier_graph_options = options->mutable_canned_gesture_classifier_graph_options(); SetExternalFile( @@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { canned_gesture_classifier_graph_options->mutable_base_options()); const auto custom_gesture_classifier_file = - resources.GetModelFile(kCustomGestureClassifierTFLiteName); + resources.GetFile(kCustomGestureClassifierTFLiteName); if (custom_gesture_classifier_file.ok()) { has_custom_gesture_classifier = true; auto* custom_gesture_classifier_graph_options = diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 21e43fc82..b37141005 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, options->mutable_hand_detector_graph_options(); if (!hand_detector_graph_options->base_options().has_model_asset()) { ASSIGN_OR_RETURN(const auto hand_detector_file, - resources.GetModelFile(kHandDetectorTFLiteName)); + resources.GetFile(kHandDetectorTFLiteName)); SetExternalFile(hand_detector_file, hand_detector_graph_options->mutable_base_options() ->mutable_model_asset(), @@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, if (!hand_landmarks_detector_graph_options->base_options() .has_model_asset()) { ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, - resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + resources.GetFile(kHandLandmarksDetectorTFLiteName)); SetExternalFile( hand_landmarks_detector_file, hand_landmarks_detector_graph_options->mutable_base_options()