Internal change

PiperOrigin-RevId: 525487344
This commit is contained in:
MediaPipe Team 2023-04-19 10:32:54 -07:00 committed by Copybara-Service
parent 9818ebb630
commit eb62479190
2 changed files with 39 additions and 7 deletions

View File

@ -318,6 +318,9 @@ cc_library(
":model_resources", ":model_resources",
":task_runner", ":task_runner",
":utils", ":utils",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:requires",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",

View File

@ -23,7 +23,11 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/requires.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
@ -54,6 +58,8 @@ class TaskApiFactory {
std::unique_ptr<tflite::OpResolver> resolver, std::unique_ptr<tflite::OpResolver> resolver,
PacketsCallback packets_callback = nullptr) { PacketsCallback packets_callback = nullptr) {
bool found_task_subgraph = false; bool found_task_subgraph = false;
// This for-loop ensures there's only one subgraph besides
// FlowLimiterCalculator.
for (const auto& node : graph_config.node()) { for (const auto& node : graph_config.node()) {
if (node.calculator() == "FlowLimiterCalculator") { if (node.calculator() == "FlowLimiterCalculator") {
continue; continue;
@ -64,13 +70,7 @@ class TaskApiFactory {
"Task graph config should only contain one task subgraph node.", "Task graph config should only contain one task subgraph node.",
MediaPipeTasksStatus::kInvalidTaskGraphConfigError); MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
} else { } else {
if (!node.options().HasExtension(Options::ext)) { MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node));
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat(node.calculator(),
" is missing the required task options field."),
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
found_task_subgraph = true; found_task_subgraph = true;
} }
} }
@ -80,6 +80,35 @@ class TaskApiFactory {
std::move(packets_callback))); std::move(packets_callback)));
return std::make_unique<T>(std::move(runner)); return std::make_unique<T>(std::move(runner));
} }
private:
template <typename Options>
static absl::Status CheckHasValidOptions(
const CalculatorGraphConfig::Node& node) {
if constexpr (mediapipe::Requires<Options>(
[](auto&& o) -> decltype(o.ext) {})) {
if (node.options().HasExtension(Options::ext)) {
return absl::OkStatus();
}
} else {
#ifndef MEDIAPIPE_PROTO_LITE
for (const auto& option : node.node_options()) {
if (absl::StrContains(option.type_url(),
Options::descriptor()->full_name())) {
return absl::OkStatus();
}
}
#else // MEDIAPIPE_PROTO_LITE
// Skip the check for proto lite, as Options::descriptor() is unavailable.
return absl::OkStatus();
#endif // MEDIAPIPE_PROTO_LITE
}
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat(node.calculator(),
" is missing the required task options field."),
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
}; };
} // namespace core } // namespace core