Rename *ModelFile to *File for methods of ModelAssetBundleResources.
PiperOrigin-RevId: 516667461
This commit is contained in:
parent
cd2cc971bb
commit
9a89b47572
|
@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
|
|||
auto model_bundle_resources = absl::WrapUnique(
|
||||
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
|
||||
MP_RETURN_IF_ERROR(
|
||||
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
|
||||
model_bundle_resources->ExtractFilesFromExternalFileProto());
|
||||
return model_bundle_resources;
|
||||
}
|
||||
|
||||
absl::Status
|
||||
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
||||
absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
|
||||
if (model_asset_bundle_file_->has_file_name()) {
|
||||
// If the model asset bundle file name is a relative path, searches the file
|
||||
// in a platform-specific location and returns the absolute path on success.
|
||||
|
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
|||
model_asset_bundle_file_handler_->GetFileContent().data();
|
||||
size_t buffer_size =
|
||||
model_asset_bundle_file_handler_->GetFileContent().size();
|
||||
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
|
||||
&model_files_);
|
||||
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
|
||||
}
|
||||
|
||||
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
|
||||
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
|
||||
const std::string& filename) const {
|
||||
auto it = model_files_.find(filename);
|
||||
if (it == model_files_.end()) {
|
||||
auto model_files = ListModelFiles();
|
||||
std::string all_model_files =
|
||||
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
|
||||
auto it = files_.find(filename);
|
||||
if (it == files_.end()) {
|
||||
auto files = ListFiles();
|
||||
std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
|
||||
|
||||
return CreateStatusWithPayload(
|
||||
StatusCode::kNotFound,
|
||||
absl::StrFormat("No model file with name: %s. All model files in the "
|
||||
"model asset bundle are: %s.",
|
||||
filename, all_model_files),
|
||||
absl::StrFormat("No file with name: %s. All files in the model asset "
|
||||
"bundle are: %s.",
|
||||
filename, all_files),
|
||||
MediaPipeTasksStatus::kFileNotFoundError);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
|
||||
std::vector<std::string> model_names;
|
||||
for (const auto& [model_name, _] : model_files_) {
|
||||
model_names.push_back(model_name);
|
||||
std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
|
||||
std::vector<std::string> file_names;
|
||||
for (const auto& [file_name, _] : files_) {
|
||||
file_names.push_back(file_name);
|
||||
}
|
||||
return model_names;
|
||||
return file_names;
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
|
|
|
@ -28,8 +28,8 @@ namespace core {
|
|||
// The mediapipe task model asset bundle resources class.
|
||||
// A ModelAssetBundleResources object, created from an external file proto,
|
||||
// contains model asset bundle related resources and the method to extract the
|
||||
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
|
||||
// resources are owned by the ModelAssetBundleResources object
|
||||
// tflite models, resource files or model asset bundles for the mediapipe
|
||||
// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
|
||||
// callers must keep ModelAssetBundleResources alive while using any of the
|
||||
// resources.
|
||||
class ModelAssetBundleResources {
|
||||
|
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
|
|||
// Returns the model asset bundle resources tag.
|
||||
std::string GetTag() const { return tag_; }
|
||||
|
||||
// Gets the contents of the model file (either tflite model file or model
|
||||
// bundle file) with the provided name. An error is returned if there is no
|
||||
// such model file.
|
||||
absl::StatusOr<absl::string_view> GetModelFile(
|
||||
const std::string& filename) const;
|
||||
// Gets the contents of the model file (either tflite model file, resource
|
||||
// file or model bundle file) with the provided name. An error is returned if
|
||||
// there is no such model file.
|
||||
absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
|
||||
|
||||
// Lists all the model file names in the model asset model.
|
||||
std::vector<std::string> ListModelFiles() const;
|
||||
// Lists all the file names in the model asset model.
|
||||
std::vector<std::string> ListFiles() const;
|
||||
|
||||
private:
|
||||
// Constructor.
|
||||
|
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
|
|||
const std::string& tag,
|
||||
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
||||
|
||||
// Extracts the model files (either tflite model file or model bundle file)
|
||||
// from the external file proto.
|
||||
absl::Status ExtractModelFilesFromExternalFileProto();
|
||||
// Extracts the model files (either tflite model file, resource file or model
|
||||
// bundle file) from the external file proto.
|
||||
absl::Status ExtractFilesFromExternalFileProto();
|
||||
|
||||
// The model asset bundle resources tag.
|
||||
const std::string tag_;
|
||||
|
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
|
|||
// The ExternalFileHandler for the model asset bundle.
|
||||
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
|
||||
|
||||
// The model files bundled in model asset bundle, as a map with the filename
|
||||
// The files bundled in model asset bundle, as a map with the filename
|
||||
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
|
||||
// a pointer to the file contents as value. Each model file can be either
|
||||
// a TFLite model file or a model bundle file for sub-task.
|
||||
absl::flat_hash_map<std::string, absl::string_view> model_files_;
|
||||
// a pointer to the file contents as value. Each file can be either a TFLite
|
||||
// model file, resource file or a model bundle file for sub-task.
|
||||
absl::flat_hash_map<std::string, absl::string_view> files_;
|
||||
};
|
||||
|
||||
} // namespace core
|
||||
|
|
|
@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||
.status());
|
||||
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||
.status());
|
||||
}
|
||||
|
||||
|
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||
.status());
|
||||
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||
.status());
|
||||
}
|
||||
|
||||
|
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||
.status());
|
||||
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||
.status());
|
||||
}
|
||||
#endif // _WIN32
|
||||
|
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
||||
.status());
|
||||
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||
MP_EXPECT_OK(
|
||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
||||
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||
.status());
|
||||
}
|
||||
|
||||
|
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
auto status_or_model_bundle_file =
|
||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
|
||||
model_bundle_resources->GetFile("dummy_hand_landmarker.task");
|
||||
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||
|
||||
// Creates sub-task model asset bundle resources.
|
||||
|
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(hand_landmaker_model_file)));
|
||||
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||
->GetModelFile("dummy_hand_detector.tflite")
|
||||
->GetFile("dummy_hand_detector.tflite")
|
||||
.status());
|
||||
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||
->GetModelFile("dummy_hand_landmarker.tflite")
|
||||
->GetFile("dummy_hand_landmarker.tflite")
|
||||
.status());
|
||||
}
|
||||
|
||||
|
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
|
|||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
auto status_or_model_bundle_file =
|
||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
|
||||
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
|
||||
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||
|
||||
// Verify tflite model works.
|
||||
|
@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
|
|||
auto model_bundle_resources,
|
||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
|
||||
auto status = model_bundle_resources->GetFile("not_found.task").status();
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
||||
EXPECT_THAT(status.message(),
|
||||
testing::HasSubstr(
|
||||
"No model file with name: not_found.task. All model files in "
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr("No file with name: not_found.task. All files in "
|
||||
"the model asset bundle are: "));
|
||||
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
||||
testing::Optional(absl::Cord(
|
||||
|
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
|
|||
auto model_bundle_resources,
|
||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||
std::move(model_file)));
|
||||
auto model_files = model_bundle_resources->ListModelFiles();
|
||||
auto model_files = model_bundle_resources->ListFiles();
|
||||
std::vector<std::string> expected_model_files = {
|
||||
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
|
||||
std::sort(model_files.begin(), model_files.end());
|
||||
|
|
|
@ -116,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
options->mutable_face_detector_graph_options();
|
||||
if (!face_detector_graph_options->base_options().has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto face_detector_file,
|
||||
resources.GetModelFile(kFaceDetectorTFLiteName));
|
||||
resources.GetFile(kFaceDetectorTFLiteName));
|
||||
SetExternalFile(face_detector_file,
|
||||
face_detector_graph_options->mutable_base_options()
|
||||
->mutable_model_asset(),
|
||||
|
@ -132,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
if (!face_landmarks_detector_graph_options->base_options()
|
||||
.has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto face_landmarks_detector_file,
|
||||
resources.GetModelFile(kFaceLandmarksDetectorTFLiteName));
|
||||
resources.GetFile(kFaceLandmarksDetectorTFLiteName));
|
||||
SetExternalFile(
|
||||
face_landmarks_detector_file,
|
||||
face_landmarks_detector_graph_options->mutable_base_options()
|
||||
|
@ -146,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
->set_use_stream_mode(options->base_options().use_stream_mode());
|
||||
|
||||
absl::StatusOr<absl::string_view> face_blendshape_model =
|
||||
resources.GetModelFile(kFaceBlendshapeTFLiteName);
|
||||
resources.GetFile(kFaceBlendshapeTFLiteName);
|
||||
if (face_blendshape_model.ok()) {
|
||||
SetExternalFile(*face_blendshape_model,
|
||||
face_landmarks_detector_graph_options
|
||||
|
@ -327,7 +327,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
|||
// Set the face geometry metdata file for
|
||||
// FaceGeometryFromLandmarksGraph.
|
||||
ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
|
||||
model_asset_bundle_resources->GetModelFile(
|
||||
model_asset_bundle_resources->GetFile(
|
||||
kFaceGeometryPipelineMetadataName));
|
||||
SetExternalFile(face_geometry_pipeline_metadata_file,
|
||||
sc->MutableOptions<FaceLandmarkerGraphOptions>()
|
||||
|
|
|
@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
GestureRecognizerGraphOptions* options,
|
||||
bool is_copy) {
|
||||
ASSIGN_OR_RETURN(const auto hand_landmarker_file,
|
||||
resources.GetModelFile(kHandLandmarkerBundleAssetName));
|
||||
resources.GetFile(kHandLandmarkerBundleAssetName));
|
||||
auto* hand_landmarker_graph_options =
|
||||
options->mutable_hand_landmarker_graph_options();
|
||||
SetExternalFile(hand_landmarker_file,
|
||||
|
@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
|
||||
options->base_options().use_stream_mode());
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto hand_gesture_recognizer_file,
|
||||
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
|
||||
ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file,
|
||||
resources.GetFile(kHandGestureRecognizerBundleAssetName));
|
||||
auto* hand_gesture_recognizer_graph_options =
|
||||
options->mutable_hand_gesture_recognizer_graph_options();
|
||||
SetExternalFile(hand_gesture_recognizer_file,
|
||||
|
|
|
@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
HandGestureRecognizerGraphOptions* options,
|
||||
bool is_copy) {
|
||||
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
|
||||
resources.GetModelFile(kGestureEmbedderTFLiteName));
|
||||
resources.GetFile(kGestureEmbedderTFLiteName));
|
||||
auto* gesture_embedder_graph_options =
|
||||
options->mutable_gesture_embedder_graph_options();
|
||||
SetExternalFile(gesture_embedder_file,
|
||||
|
@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
options->base_options(),
|
||||
gesture_embedder_graph_options->mutable_base_options());
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto canned_gesture_classifier_file,
|
||||
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
|
||||
ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
|
||||
resources.GetFile(kCannedGestureClassifierTFLiteName));
|
||||
auto* canned_gesture_classifier_graph_options =
|
||||
options->mutable_canned_gesture_classifier_graph_options();
|
||||
SetExternalFile(
|
||||
|
@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
canned_gesture_classifier_graph_options->mutable_base_options());
|
||||
|
||||
const auto custom_gesture_classifier_file =
|
||||
resources.GetModelFile(kCustomGestureClassifierTFLiteName);
|
||||
resources.GetFile(kCustomGestureClassifierTFLiteName);
|
||||
if (custom_gesture_classifier_file.ok()) {
|
||||
has_custom_gesture_classifier = true;
|
||||
auto* custom_gesture_classifier_graph_options =
|
||||
|
|
|
@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
options->mutable_hand_detector_graph_options();
|
||||
if (!hand_detector_graph_options->base_options().has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto hand_detector_file,
|
||||
resources.GetModelFile(kHandDetectorTFLiteName));
|
||||
resources.GetFile(kHandDetectorTFLiteName));
|
||||
SetExternalFile(hand_detector_file,
|
||||
hand_detector_graph_options->mutable_base_options()
|
||||
->mutable_model_asset(),
|
||||
|
@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
if (!hand_landmarks_detector_graph_options->base_options()
|
||||
.has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
|
||||
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
|
||||
resources.GetFile(kHandLandmarksDetectorTFLiteName));
|
||||
SetExternalFile(
|
||||
hand_landmarks_detector_file,
|
||||
hand_landmarks_detector_graph_options->mutable_base_options()
|
||||
|
|
Loading…
Reference in New Issue
Block a user