Update base audio/vision tasks api to suuport proto3 graph options.
PiperOrigin-RevId: 538661975
This commit is contained in:
parent
a7cd7b9a32
commit
943445fba8
|
@ -43,6 +43,7 @@ cc_library(
|
|||
":base_audio_task_api",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -60,13 +61,8 @@ class AudioTaskApiFactory {
|
|||
"Task graph config should only contain one task subgraph node.",
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
} else {
|
||||
if (!node.options().HasExtension(Options::ext)) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat(node.calculator(),
|
||||
" is missing the required task options field."),
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(
|
||||
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
|
||||
found_task_subgraph = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,7 +81,6 @@ class TaskApiFactory {
|
|||
return std::make_unique<T>(std::move(runner));
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename Options>
|
||||
static absl::Status CheckHasValidOptions(
|
||||
const CalculatorGraphConfig::Node& node) {
|
||||
|
|
|
@ -43,6 +43,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:rect",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
|
@ -58,6 +59,7 @@ cc_library(
|
|||
":base_vision_task_api",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
|
@ -60,13 +61,8 @@ class VisionTaskApiFactory {
|
|||
"Task graph config should only contain one task subgraph node.",
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
} else {
|
||||
if (!node.options().HasExtension(Options::ext)) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat(node.calculator(),
|
||||
" is missing the required task options field."),
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(
|
||||
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
|
||||
found_task_subgraph = true;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user