Internal change
PiperOrigin-RevId: 525487344
This commit is contained in:
parent
9818ebb630
commit
eb62479190
|
@ -318,6 +318,9 @@ cc_library(
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":task_runner",
|
":task_runner",
|
||||||
":utils",
|
":utils",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/port:requires",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
|
|
|
@ -23,7 +23,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/port/requires.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
@ -54,6 +58,8 @@ class TaskApiFactory {
|
||||||
std::unique_ptr<tflite::OpResolver> resolver,
|
std::unique_ptr<tflite::OpResolver> resolver,
|
||||||
PacketsCallback packets_callback = nullptr) {
|
PacketsCallback packets_callback = nullptr) {
|
||||||
bool found_task_subgraph = false;
|
bool found_task_subgraph = false;
|
||||||
|
// This for-loop ensures there's only one subgraph besides
|
||||||
|
// FlowLimiterCalculator.
|
||||||
for (const auto& node : graph_config.node()) {
|
for (const auto& node : graph_config.node()) {
|
||||||
if (node.calculator() == "FlowLimiterCalculator") {
|
if (node.calculator() == "FlowLimiterCalculator") {
|
||||||
continue;
|
continue;
|
||||||
|
@ -64,13 +70,7 @@ class TaskApiFactory {
|
||||||
"Task graph config should only contain one task subgraph node.",
|
"Task graph config should only contain one task subgraph node.",
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
} else {
|
} else {
|
||||||
if (!node.options().HasExtension(Options::ext)) {
|
MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node));
|
||||||
return CreateStatusWithPayload(
|
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat(node.calculator(),
|
|
||||||
" is missing the required task options field."),
|
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
|
||||||
}
|
|
||||||
found_task_subgraph = true;
|
found_task_subgraph = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,6 +80,35 @@ class TaskApiFactory {
|
||||||
std::move(packets_callback)));
|
std::move(packets_callback)));
|
||||||
return std::make_unique<T>(std::move(runner));
|
return std::make_unique<T>(std::move(runner));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename Options>
|
||||||
|
static absl::Status CheckHasValidOptions(
|
||||||
|
const CalculatorGraphConfig::Node& node) {
|
||||||
|
if constexpr (mediapipe::Requires<Options>(
|
||||||
|
[](auto&& o) -> decltype(o.ext) {})) {
|
||||||
|
if (node.options().HasExtension(Options::ext)) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#ifndef MEDIAPIPE_PROTO_LITE
|
||||||
|
for (const auto& option : node.node_options()) {
|
||||||
|
if (absl::StrContains(option.type_url(),
|
||||||
|
Options::descriptor()->full_name())) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else // MEDIAPIPE_PROTO_LITE
|
||||||
|
// Skip the check for proto lite, as Options::descriptor() is unavailable.
|
||||||
|
return absl::OkStatus();
|
||||||
|
#endif // MEDIAPIPE_PROTO_LITE
|
||||||
|
}
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrCat(node.calculator(),
|
||||||
|
" is missing the required task options field."),
|
||||||
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
Loading…
Reference in New Issue
Block a user