Update base audio/vision tasks api to suuport proto3 graph options.
PiperOrigin-RevId: 538661975
This commit is contained in:
parent
a7cd7b9a32
commit
943445fba8
|
@ -43,6 +43,7 @@ cc_library(
|
||||||
":base_audio_task_api",
|
":base_audio_task_api",
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -60,13 +61,8 @@ class AudioTaskApiFactory {
|
||||||
"Task graph config should only contain one task subgraph node.",
|
"Task graph config should only contain one task subgraph node.",
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
} else {
|
} else {
|
||||||
if (!node.options().HasExtension(Options::ext)) {
|
MP_RETURN_IF_ERROR(
|
||||||
return CreateStatusWithPayload(
|
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat(node.calculator(),
|
|
||||||
" is missing the required task options field."),
|
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
|
||||||
}
|
|
||||||
found_task_subgraph = true;
|
found_task_subgraph = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,7 +81,6 @@ class TaskApiFactory {
|
||||||
return std::make_unique<T>(std::move(runner));
|
return std::make_unique<T>(std::move(runner));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename Options>
|
template <typename Options>
|
||||||
static absl::Status CheckHasValidOptions(
|
static absl::Status CheckHasValidOptions(
|
||||||
const CalculatorGraphConfig::Node& node) {
|
const CalculatorGraphConfig::Node& node) {
|
||||||
|
|
|
@ -43,6 +43,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:rect",
|
"//mediapipe/tasks/cc/components/containers:rect",
|
||||||
"//mediapipe/tasks/cc/core:base_task_api",
|
"//mediapipe/tasks/cc/core:base_task_api",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
|
@ -58,6 +59,7 @@ cc_library(
|
||||||
":base_vision_task_api",
|
":base_vision_task_api",
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
@ -60,13 +61,8 @@ class VisionTaskApiFactory {
|
||||||
"Task graph config should only contain one task subgraph node.",
|
"Task graph config should only contain one task subgraph node.",
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
} else {
|
} else {
|
||||||
if (!node.options().HasExtension(Options::ext)) {
|
MP_RETURN_IF_ERROR(
|
||||||
return CreateStatusWithPayload(
|
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat(node.calculator(),
|
|
||||||
" is missing the required task options field."),
|
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
|
||||||
}
|
|
||||||
found_task_subgraph = true;
|
found_task_subgraph = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user