Make LoadBinaryContent work on Windows

PiperOrigin-RevId: 513330348
This commit is contained in:
Sebastian Schmidt 2023-03-01 13:47:06 -08:00 committed by Copybara-Service
parent 07fa5c2fc8
commit abfcd8ec1d
5 changed files with 27 additions and 8 deletions

View File

@ -332,9 +332,11 @@ cc_library(
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",
], ],
deps = [ deps = [
":external_file_handler",
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@flatbuffers//:runtime_cc", "@flatbuffers//:runtime_cc",
@ -375,6 +377,5 @@ cc_test(
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/metadata/utils:zip_utils",
"@org_tensorflow//tensorflow/lite/c:common",
], ],
) )

View File

@ -102,9 +102,13 @@ absl::StatusOr<std::string> PathToResourceAsFile(std::string path) {
#else #else
if (absl::StartsWith(path, "./")) { if (absl::StartsWith(path, "./")) {
path = "mediapipe" + path.substr(1); path = "mediapipe" + path.substr(1);
} else if (path[0] != '/') {
path = "mediapipe/" + path;
} }
std::string error; std::string error;
// TODO: We should ideally use `CreateForTests` when this is
// accessed from unit tests.
std::unique_ptr<::bazel::tools::cpp::runfiles::Runfiles> runfiles( std::unique_ptr<::bazel::tools::cpp::runfiles::Runfiles> runfiles(
::bazel::tools::cpp::runfiles::Runfiles::Create("", &error)); ::bazel::tools::cpp::runfiles::Runfiles::Create("", &error));
if (!runfiles) { if (!runfiles) {

View File

@ -88,6 +88,7 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
.status()); .status());
} }
#ifndef _WIN32
TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) { TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
const int model_file_descriptor = open(kTestModelBundlePath, O_RDONLY); const int model_file_descriptor = open(kTestModelBundlePath, O_RDONLY);
auto model_file = std::make_unique<proto::ExternalFile>(); auto model_file = std::make_unique<proto::ExternalFile>();
@ -103,6 +104,7 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
#endif // _WIN32
TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) { TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
auto file_content = LoadBinaryContent(kTestModelBundlePath); auto file_content = LoadBinaryContent(kTestModelBundlePath);

View File

@ -136,6 +136,7 @@ TEST_F(ModelResourcesTest, CreateFromFile) {
CheckModelResourcesPackets(model_resources.get()); CheckModelResourcesPackets(model_resources.get());
} }
#ifndef _WIN32
TEST_F(ModelResourcesTest, CreateFromFileDescriptor) { TEST_F(ModelResourcesTest, CreateFromFileDescriptor) {
const int model_file_descriptor = open(kTestModelPath, O_RDONLY); const int model_file_descriptor = open(kTestModelPath, O_RDONLY);
auto model_file = std::make_unique<proto::ExternalFile>(); auto model_file = std::make_unique<proto::ExternalFile>();
@ -145,6 +146,7 @@ TEST_F(ModelResourcesTest, CreateFromFileDescriptor) {
ModelResources::Create(kTestModelResourcesTag, std::move(model_file))); ModelResources::Create(kTestModelResourcesTag, std::move(model_file)));
CheckModelResourcesPackets(model_resources.get()); CheckModelResourcesPackets(model_resources.get());
} }
#endif // _WIN32
TEST_F(ModelResourcesTest, CreateFromInvalidFile) { TEST_F(ModelResourcesTest, CreateFromInvalidFile) {
auto model_file = std::make_unique<proto::ExternalFile>(); auto model_file = std::make_unique<proto::ExternalFile>();
@ -168,6 +170,15 @@ TEST_F(ModelResourcesTest, CreateFromInvalidFileDescriptor) {
auto status_or_model_resources = auto status_or_model_resources =
ModelResources::Create(kTestModelResourcesTag, std::move(model_file)); ModelResources::Create(kTestModelResourcesTag, std::move(model_file));
#ifdef _WIN32
EXPECT_EQ(status_or_model_resources.status().code(),
absl::StatusCode::kFailedPrecondition);
EXPECT_THAT(
status_or_model_resources.status().message(),
testing::HasSubstr("File descriptors are not supported on Windows."));
AssertStatusHasMediaPipeTasksStatusCode(status_or_model_resources.status(),
MediaPipeTasksStatus::kFileReadError);
#else
EXPECT_EQ(status_or_model_resources.status().code(), EXPECT_EQ(status_or_model_resources.status().code(),
absl::StatusCode::kInvalidArgument); absl::StatusCode::kInvalidArgument);
EXPECT_THAT( EXPECT_THAT(
@ -176,6 +187,7 @@ TEST_F(ModelResourcesTest, CreateFromInvalidFileDescriptor) {
AssertStatusHasMediaPipeTasksStatusCode( AssertStatusHasMediaPipeTasksStatusCode(
status_or_model_resources.status(), status_or_model_resources.status(),
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
#endif // _WIN32
} }
TEST_F(ModelResourcesTest, CreateFailWithCorruptedFile) { TEST_F(ModelResourcesTest, CreateFailWithCorruptedFile) {

View File

@ -23,6 +23,8 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h" #include "flatbuffers/flatbuffers.h"
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
#include "mediapipe/tasks/cc/core/external_file_handler.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -34,13 +36,11 @@ constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator";
} // namespace } // namespace
std::string LoadBinaryContent(const char* filename) { std::string LoadBinaryContent(const char* filename) {
std::ifstream input_file(filename, std::ios::binary | std::ios::ate); proto::ExternalFile external_file;
// Find buffer size from input file, and load the buffer. external_file.set_file_name(filename);
size_t buffer_size = input_file.tellg(); auto file_handler =
std::string buffer(buffer_size, '\0'); ExternalFileHandler::CreateFromExternalFile(&external_file);
input_file.seekg(0, std::ios::beg); return std::string{(*file_handler)->GetFileContent()};
input_file.read(const_cast<char*>(buffer.c_str()), buffer_size);
return buffer;
} }
int FindTensorIndexByMetadataName( int FindTensorIndexByMetadataName(