From a00759007d1f4b8db3afe6b395efc07d692ec1e0 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 2 Oct 2023 09:47:19 -0700 Subject: [PATCH] Add error handling to C API PiperOrigin-RevId: 570094642 --- mediapipe/tasks/c/text/text_classifier/BUILD | 2 +- .../c/text/text_classifier/text_classifier.cc | 40 +++++++++++++------ .../c/text/text_classifier/text_classifier.h | 21 ++++++++-- .../text_classifier/text_classifier_test.cc | 20 ++++++++++ 4 files changed, 65 insertions(+), 18 deletions(-) diff --git a/mediapipe/tasks/c/text/text_classifier/BUILD b/mediapipe/tasks/c/text/text_classifier/BUILD index 29e421567..1d2924e87 100644 --- a/mediapipe/tasks/c/text/text_classifier/BUILD +++ b/mediapipe/tasks/c/text/text_classifier/BUILD @@ -30,6 +30,7 @@ cc_library( "//mediapipe/tasks/c/core:base_options_converter", "//mediapipe/tasks/cc/text/text_classifier", "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", ], alwayslink = 1, ) @@ -78,7 +79,6 @@ cc_test( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:gtest", "//mediapipe/tasks/c/components/containers:category", - "//mediapipe/tasks/cc/components/containers:category", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc index cc3d44a69..f25c572e6 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/log/absl_log.h" +#include "absl/status/status.h" #include "mediapipe/tasks/c/components/containers/classification_result_converter.h" #include "mediapipe/tasks/c/components/processors/classifier_options.h" #include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" @@ -38,9 +39,18 @@ using ::mediapipe::tasks::c::components::processors:: CppConvertToClassifierOptions; using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; using ::mediapipe::tasks::text::text_classifier::TextClassifier; + +int CppProcessError(absl::Status status, char** error_msg) { + if (error_msg) { + *error_msg = strdup(status.ToString().c_str()); + } + return status.raw_code(); +} + } // namespace -TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) { +TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options, + char** error_msg) { auto cpp_options = std::make_unique< ::mediapipe::tasks::text::text_classifier::TextClassifierOptions>(); @@ -52,49 +62,53 @@ TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) { if (!classifier.ok()) { ABSL_LOG(ERROR) << "Failed to create TextClassifier: " << classifier.status(); + CppProcessError(classifier.status(), error_msg); return nullptr; } return classifier->release(); } -bool CppTextClassifierClassify(void* classifier, const char* utf8_str, - TextClassifierResult* result) { +int CppTextClassifierClassify(void* classifier, const char* utf8_str, + TextClassifierResult* result, char** error_msg) { auto cpp_classifier = static_cast(classifier); auto cpp_result = cpp_classifier->Classify(utf8_str); if (!cpp_result.ok()) { ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status(); - return false; + return CppProcessError(cpp_result.status(), error_msg); } CppConvertToClassificationResult(*cpp_result, result); - return true; + return 0; } void CppTextClassifierCloseResult(TextClassifierResult* result) { CppCloseClassificationResult(result); } -void CppTextClassifierClose(void* classifier) { +int CppTextClassifierClose(void* classifier, char** error_msg) { auto cpp_classifier = static_cast(classifier); auto result = cpp_classifier->Close(); if (!result.ok()) { ABSL_LOG(ERROR) << "Failed to close TextClassifier: " << result; + return CppProcessError(result, error_msg); } delete cpp_classifier; + return 0; } } // namespace mediapipe::tasks::c::text::text_classifier extern "C" { -void* text_classifier_create(struct TextClassifierOptions* options) { +void* text_classifier_create(struct TextClassifierOptions* options, + char** error_msg) { return mediapipe::tasks::c::text::text_classifier::CppTextClassifierCreate( - *options); + *options, error_msg); } int text_classifier_classify(void* classifier, const char* utf8_str, - TextClassifierResult* result) { + TextClassifierResult* result, char** error_msg) { return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify( - classifier, utf8_str, result); + classifier, utf8_str, result, error_msg); } void text_classifier_close_result(TextClassifierResult* result) { @@ -102,9 +116,9 @@ void text_classifier_close_result(TextClassifierResult* result) { result); } -void text_classifier_close(void* classifier) { - mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( - classifier); +int text_classifier_close(void* classifier, char** error_ms) { + return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( + classifier, error_ms); } } // extern "C" diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.h b/mediapipe/tasks/c/text/text_classifier/text_classifier.h index 7794eb0b3..057b00f99 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier.h +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.h @@ -42,18 +42,31 @@ struct TextClassifierOptions { }; // Creates a TextClassifier from the provided `options`. -MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options); +// Returns a pointer to the text classifier on success. +// If an error occurs, returns `nullptr` and sets the error parameter to an +// an error message (if `error_msg` is not nullptr). You must free the memory +// allocated for the error message. +MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options, + char** error_msg = nullptr); -// Performs classification on the input `text`. +// Performs classification on the input `text`. Returns `0` on success. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not nullptr). You must free the memory +// allocated for the error message. MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str, - TextClassifierResult* result); + TextClassifierResult* result, + char** error_msg = nullptr); // Frees the memory allocated inside a TextClassifierResult result. Does not // free the result pointer itself. MP_EXPORT void text_classifier_close_result(TextClassifierResult* result); // Shuts down the TextClassifier when all the work is done. Frees all memory. -MP_EXPORT void text_classifier_close(void* classifier); +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not nullptr). You must free the memory +// allocated for the error message. +MP_EXPORT int text_classifier_close(void* classifier, + char** error_msg = nullptr); #ifdef __cplusplus } // extern C diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc index e3815dd5f..51232d63a 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc @@ -15,17 +15,20 @@ limitations under the License. #include "mediapipe/tasks/c/text/text_classifier/text_classifier.h" +#include #include #include "absl/flags/flag.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/tasks/c/components/containers/category.h" namespace { using ::mediapipe::file::JoinPath; +using testing::HasSubstr; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; @@ -67,4 +70,21 @@ TEST(TextClassifierTest, SmokeTest) { text_classifier_close(classifier); } +TEST(TextClassifierTest, ErrorHandling) { + // It is an error to set neither the asset buffer nor the path. + TextClassifierOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ nullptr}, + /* classifier_options= */ {}, + }; + + char* error_msg; + void* classifier = text_classifier_create(&options, &error_msg); + EXPECT_EQ(classifier, nullptr); + + EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT")); + + free(error_msg); +} + } // namespace