Internal change
PiperOrigin-RevId: 525487344
This commit is contained in:
parent
9818ebb630
commit
eb62479190
|
@ -318,6 +318,9 @@ cc_library(
|
|||
":model_resources",
|
||||
":task_runner",
|
||||
":utils",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/port:requires",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||
|
|
|
@ -23,7 +23,11 @@ limitations under the License.
|
|||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/port/requires.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
@ -54,6 +58,8 @@ class TaskApiFactory {
|
|||
std::unique_ptr<tflite::OpResolver> resolver,
|
||||
PacketsCallback packets_callback = nullptr) {
|
||||
bool found_task_subgraph = false;
|
||||
// This for-loop ensures there's only one subgraph besides
|
||||
// FlowLimiterCalculator.
|
||||
for (const auto& node : graph_config.node()) {
|
||||
if (node.calculator() == "FlowLimiterCalculator") {
|
||||
continue;
|
||||
|
@ -64,13 +70,7 @@ class TaskApiFactory {
|
|||
"Task graph config should only contain one task subgraph node.",
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
} else {
|
||||
if (!node.options().HasExtension(Options::ext)) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat(node.calculator(),
|
||||
" is missing the required task options field."),
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node));
|
||||
found_task_subgraph = true;
|
||||
}
|
||||
}
|
||||
|
@ -80,6 +80,35 @@ class TaskApiFactory {
|
|||
std::move(packets_callback)));
|
||||
return std::make_unique<T>(std::move(runner));
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename Options>
|
||||
static absl::Status CheckHasValidOptions(
|
||||
const CalculatorGraphConfig::Node& node) {
|
||||
if constexpr (mediapipe::Requires<Options>(
|
||||
[](auto&& o) -> decltype(o.ext) {})) {
|
||||
if (node.options().HasExtension(Options::ext)) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} else {
|
||||
#ifndef MEDIAPIPE_PROTO_LITE
|
||||
for (const auto& option : node.node_options()) {
|
||||
if (absl::StrContains(option.type_url(),
|
||||
Options::descriptor()->full_name())) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
#else // MEDIAPIPE_PROTO_LITE
|
||||
// Skip the check for proto lite, as Options::descriptor() is unavailable.
|
||||
return absl::OkStatus();
|
||||
#endif // MEDIAPIPE_PROTO_LITE
|
||||
}
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat(node.calculator(),
|
||||
" is missing the required task options field."),
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace core
|
||||
|
|
Loading…
Reference in New Issue
Block a user