Rename *ModelFile to *File for methods of ModelAssetBundleResources.

PiperOrigin-RevId: 516667461
This commit is contained in:
MediaPipe Team 2023-03-14 16:40:05 -07:00 committed by Copybara-Service
parent cd2cc971bb
commit 9a89b47572
7 changed files with 62 additions and 72 deletions

View File

@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
auto model_bundle_resources = absl::WrapUnique( auto model_bundle_resources = absl::WrapUnique(
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file))); new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
model_bundle_resources->ExtractModelFilesFromExternalFileProto()); model_bundle_resources->ExtractFilesFromExternalFileProto());
return model_bundle_resources; return model_bundle_resources;
} }
absl::Status absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
if (model_asset_bundle_file_->has_file_name()) { if (model_asset_bundle_file_->has_file_name()) {
// If the model asset bundle file name is a relative path, searches the file // If the model asset bundle file name is a relative path, searches the file
// in a platform-specific location and returns the absolute path on success. // in a platform-specific location and returns the absolute path on success.
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
model_asset_bundle_file_handler_->GetFileContent().data(); model_asset_bundle_file_handler_->GetFileContent().data();
size_t buffer_size = size_t buffer_size =
model_asset_bundle_file_handler_->GetFileContent().size(); model_asset_bundle_file_handler_->GetFileContent().size();
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
&model_files_);
} }
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile( absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
const std::string& filename) const { const std::string& filename) const {
auto it = model_files_.find(filename); auto it = files_.find(filename);
if (it == model_files_.end()) { if (it == files_.end()) {
auto model_files = ListModelFiles(); auto files = ListFiles();
std::string all_model_files = std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
return CreateStatusWithPayload( return CreateStatusWithPayload(
StatusCode::kNotFound, StatusCode::kNotFound,
absl::StrFormat("No model file with name: %s. All model files in the " absl::StrFormat("No file with name: %s. All files in the model asset "
"model asset bundle are: %s.", "bundle are: %s.",
filename, all_model_files), filename, all_files),
MediaPipeTasksStatus::kFileNotFoundError); MediaPipeTasksStatus::kFileNotFoundError);
} }
return it->second; return it->second;
} }
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const { std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
std::vector<std::string> model_names; std::vector<std::string> file_names;
for (const auto& [model_name, _] : model_files_) { for (const auto& [file_name, _] : files_) {
model_names.push_back(model_name); file_names.push_back(file_name);
} }
return model_names; return file_names;
} }
} // namespace core } // namespace core

View File

@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class. // The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto, // A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the // contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the // tflite models, resource files or model asset bundles for the mediapipe
// resources are owned by the ModelAssetBundleResources object // sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the // callers must keep ModelAssetBundleResources alive while using any of the
// resources. // resources.
class ModelAssetBundleResources { class ModelAssetBundleResources {
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
// Returns the model asset bundle resources tag. // Returns the model asset bundle resources tag.
std::string GetTag() const { return tag_; } std::string GetTag() const { return tag_; }
// Gets the contents of the model file (either tflite model file or model // Gets the contents of the model file (either tflite model file, resource
// bundle file) with the provided name. An error is returned if there is no // file or model bundle file) with the provided name. An error is returned if
// such model file. // there is no such model file.
absl::StatusOr<absl::string_view> GetModelFile( absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
const std::string& filename) const;
// Lists all the model file names in the model asset model. // Lists all the file names in the model asset model.
std::vector<std::string> ListModelFiles() const; std::vector<std::string> ListFiles() const;
private: private:
// Constructor. // Constructor.
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
const std::string& tag, const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file); std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// Extracts the model files (either tflite model file or model bundle file) // Extracts the model files (either tflite model file, resource file or model
// from the external file proto. // bundle file) from the external file proto.
absl::Status ExtractModelFilesFromExternalFileProto(); absl::Status ExtractFilesFromExternalFileProto();
// The model asset bundle resources tag. // The model asset bundle resources tag.
const std::string tag_; const std::string tag_;
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
// The ExternalFileHandler for the model asset bundle. // The ExternalFileHandler for the model asset bundle.
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_; std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
// The model files bundled in model asset bundle, as a map with the filename // The files bundled in model asset bundle, as a map with the filename
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and // (corresponding to a basename, e.g. "hand_detector.tflite") as key and
// a pointer to the file contents as value. Each model file can be either // a pointer to the file contents as value. Each file can be either a TFLite
// a TFLite model file or a model bundle file for sub-task. // model file, resource file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> model_files_; absl::flat_hash_map<std::string, absl::string_view> files_;
}; };
} // namespace core } // namespace core

View File

@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
#endif // _WIN32 #endif // _WIN32
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status_or_model_bundle_file = auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task"); model_bundle_resources->GetFile("dummy_hand_landmarker.task");
MP_EXPECT_OK(status_or_model_bundle_file.status()); MP_EXPECT_OK(status_or_model_bundle_file.status());
// Creates sub-task model asset bundle resources. // Creates sub-task model asset bundle resources.
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(hand_landmaker_model_file))); std::move(hand_landmaker_model_file)));
MP_EXPECT_OK(hand_landmaker_model_bundle_resources MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_detector.tflite") ->GetFile("dummy_hand_detector.tflite")
.status()); .status());
MP_EXPECT_OK(hand_landmaker_model_bundle_resources MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_landmarker.tflite") ->GetFile("dummy_hand_landmarker.tflite")
.status()); .status());
} }
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status_or_model_bundle_file = auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite"); model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
MP_EXPECT_OK(status_or_model_bundle_file.status()); MP_EXPECT_OK(status_or_model_bundle_file.status());
// Verify tflite model works. // Verify tflite model works.
@ -200,12 +196,12 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status = model_bundle_resources->GetModelFile("not_found.task").status(); auto status = model_bundle_resources->GetFile("not_found.task").status();
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(), EXPECT_THAT(
testing::HasSubstr( status.message(),
"No model file with name: not_found.task. All model files in " testing::HasSubstr("No file with name: not_found.task. All files in "
"the model asset bundle are: ")); "the model asset bundle are: "));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
testing::Optional(absl::Cord( testing::Optional(absl::Cord(
absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError)))); absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto model_files = model_bundle_resources->ListModelFiles(); auto model_files = model_bundle_resources->ListFiles();
std::vector<std::string> expected_model_files = { std::vector<std::string> expected_model_files = {
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"}; "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
std::sort(model_files.begin(), model_files.end()); std::sort(model_files.begin(), model_files.end());

View File

@ -116,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
options->mutable_face_detector_graph_options(); options->mutable_face_detector_graph_options();
if (!face_detector_graph_options->base_options().has_model_asset()) { if (!face_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_detector_file, ASSIGN_OR_RETURN(const auto face_detector_file,
resources.GetModelFile(kFaceDetectorTFLiteName)); resources.GetFile(kFaceDetectorTFLiteName));
SetExternalFile(face_detector_file, SetExternalFile(face_detector_file,
face_detector_graph_options->mutable_base_options() face_detector_graph_options->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
@ -132,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
if (!face_landmarks_detector_graph_options->base_options() if (!face_landmarks_detector_graph_options->base_options()
.has_model_asset()) { .has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_landmarks_detector_file, ASSIGN_OR_RETURN(const auto face_landmarks_detector_file,
resources.GetModelFile(kFaceLandmarksDetectorTFLiteName)); resources.GetFile(kFaceLandmarksDetectorTFLiteName));
SetExternalFile( SetExternalFile(
face_landmarks_detector_file, face_landmarks_detector_file,
face_landmarks_detector_graph_options->mutable_base_options() face_landmarks_detector_graph_options->mutable_base_options()
@ -146,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
absl::StatusOr<absl::string_view> face_blendshape_model = absl::StatusOr<absl::string_view> face_blendshape_model =
resources.GetModelFile(kFaceBlendshapeTFLiteName); resources.GetFile(kFaceBlendshapeTFLiteName);
if (face_blendshape_model.ok()) { if (face_blendshape_model.ok()) {
SetExternalFile(*face_blendshape_model, SetExternalFile(*face_blendshape_model,
face_landmarks_detector_graph_options face_landmarks_detector_graph_options
@ -327,7 +327,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
// Set the face geometry metdata file for // Set the face geometry metdata file for
// FaceGeometryFromLandmarksGraph. // FaceGeometryFromLandmarksGraph.
ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file, ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
model_asset_bundle_resources->GetModelFile( model_asset_bundle_resources->GetFile(
kFaceGeometryPipelineMetadataName)); kFaceGeometryPipelineMetadataName));
SetExternalFile(face_geometry_pipeline_metadata_file, SetExternalFile(face_geometry_pipeline_metadata_file,
sc->MutableOptions<FaceLandmarkerGraphOptions>() sc->MutableOptions<FaceLandmarkerGraphOptions>()

View File

@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
GestureRecognizerGraphOptions* options, GestureRecognizerGraphOptions* options,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_landmarker_file, ASSIGN_OR_RETURN(const auto hand_landmarker_file,
resources.GetModelFile(kHandLandmarkerBundleAssetName)); resources.GetFile(kHandLandmarkerBundleAssetName));
auto* hand_landmarker_graph_options = auto* hand_landmarker_graph_options =
options->mutable_hand_landmarker_graph_options(); options->mutable_hand_landmarker_graph_options();
SetExternalFile(hand_landmarker_file, SetExternalFile(hand_landmarker_file,
@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode()); options->base_options().use_stream_mode());
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file,
const auto hand_gesture_recognizer_file, resources.GetFile(kHandGestureRecognizerBundleAssetName));
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
auto* hand_gesture_recognizer_graph_options = auto* hand_gesture_recognizer_graph_options =
options->mutable_hand_gesture_recognizer_graph_options(); options->mutable_hand_gesture_recognizer_graph_options();
SetExternalFile(hand_gesture_recognizer_file, SetExternalFile(hand_gesture_recognizer_file,

View File

@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
HandGestureRecognizerGraphOptions* options, HandGestureRecognizerGraphOptions* options,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file, ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName)); resources.GetFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options = auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options(); options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file, SetExternalFile(gesture_embedder_file,
@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
options->base_options(), options->base_options(),
gesture_embedder_graph_options->mutable_base_options()); gesture_embedder_graph_options->mutable_base_options());
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
const auto canned_gesture_classifier_file, resources.GetFile(kCannedGestureClassifierTFLiteName));
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options = auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options(); options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile( SetExternalFile(
@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
canned_gesture_classifier_graph_options->mutable_base_options()); canned_gesture_classifier_graph_options->mutable_base_options());
const auto custom_gesture_classifier_file = const auto custom_gesture_classifier_file =
resources.GetModelFile(kCustomGestureClassifierTFLiteName); resources.GetFile(kCustomGestureClassifierTFLiteName);
if (custom_gesture_classifier_file.ok()) { if (custom_gesture_classifier_file.ok()) {
has_custom_gesture_classifier = true; has_custom_gesture_classifier = true;
auto* custom_gesture_classifier_graph_options = auto* custom_gesture_classifier_graph_options =

View File

@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
options->mutable_hand_detector_graph_options(); options->mutable_hand_detector_graph_options();
if (!hand_detector_graph_options->base_options().has_model_asset()) { if (!hand_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto hand_detector_file, ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName)); resources.GetFile(kHandDetectorTFLiteName));
SetExternalFile(hand_detector_file, SetExternalFile(hand_detector_file,
hand_detector_graph_options->mutable_base_options() hand_detector_graph_options->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
if (!hand_landmarks_detector_graph_options->base_options() if (!hand_landmarks_detector_graph_options->base_options()
.has_model_asset()) { .has_model_asset()) {
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); resources.GetFile(kHandLandmarksDetectorTFLiteName));
SetExternalFile( SetExternalFile(
hand_landmarks_detector_file, hand_landmarks_detector_file,
hand_landmarks_detector_graph_options->mutable_base_options() hand_landmarks_detector_graph_options->mutable_base_options()