diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 5aa9c9729..95cfdd15e 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -318,6 +318,9 @@ cc_library( ":model_resources", ":task_runner", ":utils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:requires", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", diff --git a/mediapipe/tasks/cc/core/task_api_factory.h b/mediapipe/tasks/cc/core/task_api_factory.h index 631696b4c..83c2f3207 100644 --- a/mediapipe/tasks/cc/core/task_api_factory.h +++ b/mediapipe/tasks/cc/core/task_api_factory.h @@ -23,7 +23,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/port/requires.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -54,6 +58,8 @@ class TaskApiFactory { std::unique_ptr resolver, PacketsCallback packets_callback = nullptr) { bool found_task_subgraph = false; + // This for-loop ensures there's only one subgraph besides + // FlowLimiterCalculator. for (const auto& node : graph_config.node()) { if (node.calculator() == "FlowLimiterCalculator") { continue; @@ -64,13 +70,7 @@ class TaskApiFactory { "Task graph config should only contain one task subgraph node.", MediaPipeTasksStatus::kInvalidTaskGraphConfigError); } else { - if (!node.options().HasExtension(Options::ext)) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat(node.calculator(), - " is missing the required task options field."), - MediaPipeTasksStatus::kInvalidTaskGraphConfigError); - } + MP_RETURN_IF_ERROR(CheckHasValidOptions(node)); found_task_subgraph = true; } } @@ -80,6 +80,35 @@ class TaskApiFactory { std::move(packets_callback))); return std::make_unique(std::move(runner)); } + + private: + template + static absl::Status CheckHasValidOptions( + const CalculatorGraphConfig::Node& node) { + if constexpr (mediapipe::Requires( + [](auto&& o) -> decltype(o.ext) {})) { + if (node.options().HasExtension(Options::ext)) { + return absl::OkStatus(); + } + } else { +#ifndef MEDIAPIPE_PROTO_LITE + for (const auto& option : node.node_options()) { + if (absl::StrContains(option.type_url(), + Options::descriptor()->full_name())) { + return absl::OkStatus(); + } + } +#else // MEDIAPIPE_PROTO_LITE + // Skip the check for proto lite, as Options::descriptor() is unavailable. + return absl::OkStatus(); +#endif // MEDIAPIPE_PROTO_LITE + } + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat(node.calculator(), + " is missing the required task options field."), + MediaPipeTasksStatus::kInvalidTaskGraphConfigError); + } }; } // namespace core