diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index ce9181d51..fa61feb9d 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -261,6 +261,7 @@ cc_library_with_tflite( deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:executor", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:name_util", "//mediapipe/tasks/cc:common", @@ -319,6 +320,7 @@ cc_library( ":task_runner", ":utils", "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:executor", "//mediapipe/framework/port:requires", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", diff --git a/mediapipe/tasks/cc/core/task_api_factory.h b/mediapipe/tasks/cc/core/task_api_factory.h index 6291f361e..3f173818f 100644 --- a/mediapipe/tasks/cc/core/task_api_factory.h +++ b/mediapipe/tasks/cc/core/task_api_factory.h @@ -17,6 +17,7 @@ limitations under the License. #define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_ #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/executor.h" #include "mediapipe/framework/port/requires.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" @@ -56,7 +58,9 @@ class TaskApiFactory { static absl::StatusOr> Create( CalculatorGraphConfig graph_config, std::unique_ptr resolver, - PacketsCallback packets_callback = nullptr) { + PacketsCallback packets_callback = nullptr, + std::shared_ptr default_executor = nullptr, + std::optional input_side_packets = std::nullopt) { bool found_task_subgraph = false; // This for-loop ensures there's only one subgraph besides // FlowLimiterCalculator. @@ -77,7 +81,9 @@ class TaskApiFactory { MP_ASSIGN_OR_RETURN( auto runner, core::TaskRunner::Create(std::move(graph_config), std::move(resolver), - std::move(packets_callback))); + std::move(packets_callback), + std::move(default_executor), + std::move(input_side_packets))); return std::make_unique(std::move(runner)); } diff --git a/mediapipe/tasks/cc/core/task_runner.cc b/mediapipe/tasks/cc/core/task_runner.cc index 797dfa35f..88c91bcdb 100644 --- a/mediapipe/tasks/cc/core/task_runner.cc +++ b/mediapipe/tasks/cc/core/task_runner.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/executor.h" #include "mediapipe/framework/tool/name_util.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" @@ -89,17 +91,22 @@ absl::StatusOr GenerateOutputPacketMap( absl::StatusOr> TaskRunner::Create( CalculatorGraphConfig config, std::unique_ptr op_resolver, - PacketsCallback packets_callback) { + PacketsCallback packets_callback, + std::shared_ptr default_executor, + std::optional input_side_packets) { auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); - MP_RETURN_IF_ERROR( - task_runner->Initialize(std::move(config), std::move(op_resolver))); + MP_RETURN_IF_ERROR(task_runner->Initialize( + std::move(config), std::move(op_resolver), std::move(default_executor), + std::move(input_side_packets))); MP_RETURN_IF_ERROR(task_runner->Start()); return task_runner; } absl::Status TaskRunner::Initialize( CalculatorGraphConfig config, - std::unique_ptr op_resolver) { + std::unique_ptr op_resolver, + std::shared_ptr default_executor, + std::optional input_side_packets) { if (initialized_) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -123,7 +130,9 @@ absl::Status TaskRunner::Initialize( MediaPipeTasksStatus::kRunnerInitializationError); } config.clear_output_stream(); - PacketMap input_side_packets; + if (!input_side_packets) { + input_side_packets.emplace(); + } if (packets_callback_) { tool::AddMultiStreamCallback( output_stream_names_, @@ -132,7 +141,7 @@ absl::Status TaskRunner::Initialize( GenerateOutputPacketMap(packets, output_stream_names_)); return; }, - &config, &input_side_packets, + &config, &input_side_packets.value(), /*observe_timestamp_bounds=*/true); } else { mediapipe::tool::AddMultiStreamCallback( @@ -142,8 +151,14 @@ absl::Status TaskRunner::Initialize( GenerateOutputPacketMap(packets, output_stream_names_); return; }, - &config, &input_side_packets, /*observe_timestamp_bounds=*/true); + &config, &input_side_packets.value(), + /*observe_timestamp_bounds=*/true); } + + if (default_executor) { + MP_RETURN_IF_ERROR(graph_.SetExecutor("", std::move(default_executor))); + } + auto model_resources_cache = std::make_shared(std::move(op_resolver)); MP_RETURN_IF_ERROR( @@ -152,7 +167,7 @@ absl::Status TaskRunner::Initialize( "ModelResourcesCacheService is not set up successfully.", MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError)); MP_RETURN_IF_ERROR( - AddPayload(graph_.Initialize(std::move(config), input_side_packets), + AddPayload(graph_.Initialize(std::move(config), *input_side_packets), "MediaPipe CalculatorGraph is not successfully initialized.", MediaPipeTasksStatus::kRunnerInitializationError)); initialized_ = true; diff --git a/mediapipe/tasks/cc/core/task_runner.h b/mediapipe/tasks/cc/core/task_runner.h index cd77c0555..810063d4b 100644 --- a/mediapipe/tasks/cc/core/task_runner.h +++ b/mediapipe/tasks/cc/core/task_runner.h @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/executor.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" @@ -73,7 +75,9 @@ class TaskRunner { static absl::StatusOr> Create( CalculatorGraphConfig config, std::unique_ptr op_resolver = nullptr, - PacketsCallback packets_callback = nullptr); + PacketsCallback packets_callback = nullptr, + std::shared_ptr default_executor = nullptr, + std::optional input_side_packets = std::nullopt); // TaskRunner is neither copyable nor movable. TaskRunner(const TaskRunner&) = delete; @@ -125,7 +129,9 @@ class TaskRunner { // be only initialized once. absl::Status Initialize( CalculatorGraphConfig config, - std::unique_ptr op_resolver = nullptr); + std::unique_ptr op_resolver = nullptr, + std::shared_ptr default_executor = nullptr, + std::optional input_side_packets = std::nullopt); // Starts the task runner. Returns an ok status to indicate that the // runner is ready to accept input data. Otherwise, returns an error status to