Rename *ModelFile to *File for methods of ModelAssetBundleResources.

PiperOrigin-RevId: 516667461
This commit is contained in:
MediaPipe Team 2023-03-14 16:40:05 -07:00 committed by Copybara-Service
parent cd2cc971bb
commit 9a89b47572
7 changed files with 62 additions and 72 deletions

View File

@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
auto model_bundle_resources = absl::WrapUnique(
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
MP_RETURN_IF_ERROR(
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
model_bundle_resources->ExtractFilesFromExternalFileProto());
return model_bundle_resources;
}
absl::Status
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
if (model_asset_bundle_file_->has_file_name()) {
// If the model asset bundle file name is a relative path, searches the file
// in a platform-specific location and returns the absolute path on success.
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
model_asset_bundle_file_handler_->GetFileContent().data();
size_t buffer_size =
model_asset_bundle_file_handler_->GetFileContent().size();
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
&model_files_);
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
}
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
const std::string& filename) const {
auto it = model_files_.find(filename);
if (it == model_files_.end()) {
auto model_files = ListModelFiles();
std::string all_model_files =
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
auto it = files_.find(filename);
if (it == files_.end()) {
auto files = ListFiles();
std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
return CreateStatusWithPayload(
StatusCode::kNotFound,
absl::StrFormat("No model file with name: %s. All model files in the "
"model asset bundle are: %s.",
filename, all_model_files),
absl::StrFormat("No file with name: %s. All files in the model asset "
"bundle are: %s.",
filename, all_files),
MediaPipeTasksStatus::kFileNotFoundError);
}
return it->second;
}
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
std::vector<std::string> model_names;
for (const auto& [model_name, _] : model_files_) {
model_names.push_back(model_name);
std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
std::vector<std::string> file_names;
for (const auto& [file_name, _] : files_) {
file_names.push_back(file_name);
}
return model_names;
return file_names;
}
} // namespace core

View File

@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
// resources are owned by the ModelAssetBundleResources object
// tflite models, resource files or model asset bundles for the mediapipe
// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the
// resources.
class ModelAssetBundleResources {
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
// Returns the model asset bundle resources tag.
std::string GetTag() const { return tag_; }
// Gets the contents of the model file (either tflite model file or model
// bundle file) with the provided name. An error is returned if there is no
// such model file.
absl::StatusOr<absl::string_view> GetModelFile(
const std::string& filename) const;
// Gets the contents of the model file (either tflite model file, resource
// file or model bundle file) with the provided name. An error is returned if
// there is no such model file.
absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
// Lists all the model file names in the model asset model.
std::vector<std::string> ListModelFiles() const;
// Lists all the file names in the model asset model.
std::vector<std::string> ListFiles() const;
private:
// Constructor.
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// Extracts the model files (either tflite model file or model bundle file)
// from the external file proto.
absl::Status ExtractModelFilesFromExternalFileProto();
// Extracts the model files (either tflite model file, resource file or model
// bundle file) from the external file proto.
absl::Status ExtractFilesFromExternalFileProto();
// The model asset bundle resources tag.
const std::string tag_;
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
// The ExternalFileHandler for the model asset bundle.
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
// The model files bundled in model asset bundle, as a map with the filename
// The files bundled in model asset bundle, as a map with the filename
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
// a pointer to the file contents as value. Each model file can be either
// a TFLite model file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> model_files_;
// a pointer to the file contents as value. Each file can be either a TFLite
// model file, resource file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> files_;
};
} // namespace core

View File

@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
#endif // _WIN32
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
model_bundle_resources->GetFile("dummy_hand_landmarker.task");
MP_EXPECT_OK(status_or_model_bundle_file.status());
// Creates sub-task model asset bundle resources.
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(hand_landmaker_model_file)));
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_detector.tflite")
->GetFile("dummy_hand_detector.tflite")
.status());
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_landmarker.tflite")
->GetFile("dummy_hand_landmarker.tflite")
.status());
}
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
MP_EXPECT_OK(status_or_model_bundle_file.status());
// Verify tflite model works.
@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
auto status = model_bundle_resources->GetFile("not_found.task").status();
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(),
testing::HasSubstr(
"No model file with name: not_found.task. All model files in "
EXPECT_THAT(
status.message(),
testing::HasSubstr("No file with name: not_found.task. All files in "
"the model asset bundle are: "));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
testing::Optional(absl::Cord(
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto model_files = model_bundle_resources->ListModelFiles();
auto model_files = model_bundle_resources->ListFiles();
std::vector<std::string> expected_model_files = {
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
std::sort(model_files.begin(), model_files.end());

View File

@ -116,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
options->mutable_face_detector_graph_options();
if (!face_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_detector_file,
resources.GetModelFile(kFaceDetectorTFLiteName));
resources.GetFile(kFaceDetectorTFLiteName));
SetExternalFile(face_detector_file,
face_detector_graph_options->mutable_base_options()
->mutable_model_asset(),
@ -132,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
if (!face_landmarks_detector_graph_options->base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_landmarks_detector_file,
resources.GetModelFile(kFaceLandmarksDetectorTFLiteName));
resources.GetFile(kFaceLandmarksDetectorTFLiteName));
SetExternalFile(
face_landmarks_detector_file,
face_landmarks_detector_graph_options->mutable_base_options()
@ -146,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->set_use_stream_mode(options->base_options().use_stream_mode());
absl::StatusOr<absl::string_view> face_blendshape_model =
resources.GetModelFile(kFaceBlendshapeTFLiteName);
resources.GetFile(kFaceBlendshapeTFLiteName);
if (face_blendshape_model.ok()) {
SetExternalFile(*face_blendshape_model,
face_landmarks_detector_graph_options
@ -327,7 +327,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
// Set the face geometry metdata file for
// FaceGeometryFromLandmarksGraph.
ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
model_asset_bundle_resources->GetModelFile(
model_asset_bundle_resources->GetFile(
kFaceGeometryPipelineMetadataName));
SetExternalFile(face_geometry_pipeline_metadata_file,
sc->MutableOptions<FaceLandmarkerGraphOptions>()

View File

@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
GestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_landmarker_file,
resources.GetModelFile(kHandLandmarkerBundleAssetName));
resources.GetFile(kHandLandmarkerBundleAssetName));
auto* hand_landmarker_graph_options =
options->mutable_hand_landmarker_graph_options();
SetExternalFile(hand_landmarker_file,
@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(
const auto hand_gesture_recognizer_file,
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file,
resources.GetFile(kHandGestureRecognizerBundleAssetName));
auto* hand_gesture_recognizer_graph_options =
options->mutable_hand_gesture_recognizer_graph_options();
SetExternalFile(hand_gesture_recognizer_file,

View File

@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
HandGestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName));
resources.GetFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file,
@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
options->base_options(),
gesture_embedder_graph_options->mutable_base_options());
ASSIGN_OR_RETURN(
const auto canned_gesture_classifier_file,
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
resources.GetFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile(
@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
canned_gesture_classifier_graph_options->mutable_base_options());
const auto custom_gesture_classifier_file =
resources.GetModelFile(kCustomGestureClassifierTFLiteName);
resources.GetFile(kCustomGestureClassifierTFLiteName);
if (custom_gesture_classifier_file.ok()) {
has_custom_gesture_classifier = true;
auto* custom_gesture_classifier_graph_options =

View File

@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
options->mutable_hand_detector_graph_options();
if (!hand_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName));
resources.GetFile(kHandDetectorTFLiteName));
SetExternalFile(hand_detector_file,
hand_detector_graph_options->mutable_base_options()
->mutable_model_asset(),
@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
if (!hand_landmarks_detector_graph_options->base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
resources.GetFile(kHandLandmarksDetectorTFLiteName));
SetExternalFile(
hand_landmarks_detector_file,
hand_landmarks_detector_graph_options->mutable_base_options()